FFT和NTT
FFT
原理
参考以下文章:
\(\omega_n\)表示n次单位根。FFT其实就是计算如下矩阵乘法: \[ \begin{bmatrix} 1&1&1&\dots&1\\ 1&(\omega_n^1)^1&(\omega_n^1)^2&\dots&(\omega_n^1)^{n-1}\\ 1&(\omega_n^2)^1&(\omega_n^2)^2&\dots&(\omega_n^2)^{n-1}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 1&(\omega_n^{n-1})^1&(\omega_n^{n-1})^2&\dots&(\omega_n^{n-1})^{n-1} \end{bmatrix} \begin{bmatrix} a_0\\ a_1\\ a_2\\ \vdots\\ a_{n-1} \end{bmatrix} = \begin{bmatrix} y_0\\ y_1\\ y_2\\ \vdots\\ y_{n-1} \end{bmatrix} \] 其中\(a_i\)是多项式的系数表示,\(y_i\)是多项式的点值表示。IFFT即点值表示乘以上面方阵的逆矩阵,求出系数表示,其中逆矩阵如下。(注意逆矩阵前面乘了\(\frac{1}{n}\),因此IFFT做完后要除以\(n\))。 \[ \frac{1}{n} \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & (\overline{\omega_n^1})^1 & (\overline{\omega_n^1})^2 & \cdots & (\overline{\omega_n^1})^{n-1} \\ 1 & (\overline{\omega_n^2})^1 & (\overline{\omega_n^2})^2 & \cdots & (\overline{\omega_n^2})^{n-1} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & (\overline{\omega_n^{n-1}})^1 & (\overline{\omega_n^{n-1}})^2 & \cdots & (\overline{\omega_n^{n-1}})^{n-1} \end{bmatrix} \]
代码
import cmath
import numpy as np
from numpy.polynomial import Polynomial
def FFT(P, inv=False):
"""
P: [p_0, p_1, ..., p_{n-1}]
"""
n = len(P)
assert n & (n - 1) == 0 # n is a power of 2
if n == 1:
return P
Pe = P[::2]
Po = P[1::2]
ye = FFT(Pe, inv)
yo = FFT(Po, inv)
w_n = cmath.exp((-2j if inv else 2j) * cmath.pi / n)
w = 1
y = [0] * n
for k in range(n // 2):
y[k] = ye[k] + w * yo[k]
y[k + n // 2] = ye[k] - w * yo[k]
w *= w_n
return y
def polynomial_multiply(a, b):
# 计算最终长度
n = len(a)
m = len(b)
size = 1
while size < n + m - 1:
size <<= 1
# 补零
a += [0] * (size - n)
b += [0] * (size - m)
# FFT变换
a_fft = FFT(a)
b_fft = FFT(b)
# 点值相乘
c_fft = [a_fft[i] * b_fft[i] for i in range(size)]
# IFFT变换
c = FFT(c_fft, True)
c = [round(x.real / size) for x in c] # 注意IFFT后要除以size
return c[:n + m - 1]
def numpy_polynomial_multiply(a_coe, b_coe):
p1 = Polynomial(a_coe)
p2 = Polynomial(b_coe)
product = p1 * p2
return product.coef
def main():
a = [1, 2, 3]
b = [4, 5, 6, 7]
result = polynomial_multiply(a, b)
print("FFT:", result)
# [4, 13, 28, 34, 32, 21]
numpy_result = numpy_polynomial_multiply(a, b)
print("NumPy:", np.round(numpy_result).astype(int))
if __name__ == "__main__":
main()
NTT
原理
参考以下文章:
就是把n次单位根替换为了原根。
代码
M = 998244353 # 常用NTT模数,998244353 = 2 ^ 23 * 119 + 1
G = 3 # 原根
G_INV = 332748118
def NTT(P, inv=False):
"""
P: [p_0, p_1, ..., p_{n-1}]
"""
n = len(P)
assert n & (n - 1) == 0 # n is a power of 2
if n == 1:
return P
Pe = P[::2]
Po = P[1::2]
ye = NTT(Pe, inv)
yo = NTT(Po, inv)
w_n = pow(G_INV if inv else G, (M - 1) // n, M)
w = 1
y = [0] * n
for k in range(n // 2):
y[k] = (ye[k] + w * yo[k]) % M
y[k + n // 2] = (ye[k] - w * yo[k]) % M
w = w * w_n % M
return y
def polynomial_multiply_ntt(a, b):
# 计算最终长度
n = len(a)
m = len(b)
size = 1
while size < n + m - 1:
size <<= 1
# 补零
a += [0] * (size - n)
b += [0] * (size - m)
# NTT变换
a_ntt = NTT(a)
b_ntt = NTT(b)
# 点值相乘
c_ntt = [a_ntt[i] * b_ntt[i] % M for i in range(size)]
# INTT变换
c = NTT(c_ntt, True)
c = [x * pow(size, -1, M) % M for x in c] # 除以size
return c[:n + m - 1]
def main():
a = [1, 2, 3]
b = [4, 5, 6, 7]
result = polynomial_multiply_ntt(a, b)
print(result)
if __name__ == "__main__":
main()
以下是使用sage计算多项式乘法,用于验证上述代码。
from sage.all import *
M = 998244353
a_coe = [1, 2, 3]
b_coe = [4, 5, 6, 7]
R = PolynomialRing(GF(M), 'x')
a = R(a_coe)
b = R(b_coe)
product = a * b
print(list(product))