FFT和NTT
FFT
参考以下文章:
本质
设多项式\(P(x)=a_{n-1}x^{n-1}+\cdots+a_1x+a_0\),\(a_i\)即是多项式的系数表示。设\(y_i\)是多项式的点值表示,\(\omega_n\)表示n次单位根(\(e^{\frac{2\pi i}{n}}\))。FFT就是将\(\omega_n\)的\(0\sim n-1\)次方依次代入多项式,即如下矩阵乘法: \[ \begin{bmatrix} 1&1&1&\dots&1\\ 1&(\omega_n^1)^1&(\omega_n^1)^2&\dots&(\omega_n^1)^{n-1}\\ 1&(\omega_n^2)^1&(\omega_n^2)^2&\dots&(\omega_n^2)^{n-1}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 1&(\omega_n^{n-1})^1&(\omega_n^{n-1})^2&\dots&(\omega_n^{n-1})^{n-1} \end{bmatrix} \begin{bmatrix} a_0\\ a_1\\ a_2\\ \vdots\\ a_{n-1} \end{bmatrix} = \begin{bmatrix} y_0\\ y_1\\ y_2\\ \vdots\\ y_{n-1} \end{bmatrix} \] 或写成: \[ y_k=\sum_{i=0}^{n-1}a_i\omega_n^{ki} \] IFFT即点值表示乘以上面方阵的逆矩阵,求出系数表示,其中逆矩阵如下。(注意逆矩阵前面乘了\(\frac{1}{n}\),因此IFFT做完后要除以\(n\))。 \[ \frac{1}{n} \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & (\overline{\omega_n^1})^1 & (\overline{\omega_n^1})^2 & \cdots & (\overline{\omega_n^1})^{n-1} \\ 1 & (\overline{\omega_n^2})^1 & (\overline{\omega_n^2})^2 & \cdots & (\overline{\omega_n^2})^{n-1} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & (\overline{\omega_n^{n-1}})^1 & (\overline{\omega_n^{n-1}})^2 & \cdots & (\overline{\omega_n^{n-1}})^{n-1} \end{bmatrix} \]
算法
显然直接计算矩阵乘法复杂度为\(O(n^2)\),下面给出\(O(n\log n)\)的算法。
将\(P(x)\)按次数的奇偶分成两部分: \[ \begin{align} P(x)=&(a_0+a_2x^2+\cdots+a_{n-2}x^{n-2})\\ +&(a_1x+a_3x^3+\cdots+a_{n-1}x^{n-1}) \end{align} \] 记 \[ \begin{align} P_e(x)&=a_0+a_2x+\cdots+a_{n-2}x^{\frac{n-2}{2}}\\ P_o(x)&=a_1+a_3x+\cdots+a_{n-1}x^{\frac{n-2}{2}} \end{align} \] 则 \[ P(x)=P_e(x^2)+xP_o(x^2) \] 代入\(x=\omega_n^k\)(\(0\le k<\frac{n}{2}\))可得 \[ \begin{align} P(\omega_n^k)&=P_e(\omega_n^{2k})+\omega_n^kP_o(\omega_n^{2k})\\ &=P_e(\omega_{\frac{n}{2}}^k)+\omega_n^kP_o(\omega_{\frac{n}{2}}^k) \end{align} \] 再代入\(x=\omega_n^{k+\frac{n}{2}}\)(\(0\le k<\frac{n}{2}\))可得 \[ \begin{align} P(\omega_n^{k+\frac{n}{2}})&=P_e(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}P_o(\omega_n^{2k+n})\\ &=P_e(\omega_n^{2k})+\omega_n^{k+\frac{n}{2}}P_o(\omega_n^{2k})\\ &=P_e(\omega_{\frac{n}{2}}^k)-\omega_n^kP_o(\omega_{\frac{n}{2}}^k) \end{align} \] > 注意旋转因子(Twiddle Factor)为\(\omega_n^k\)(\(0\le k<\frac{n}{2}\)),和\(P_e,P_o\)中的自变量下标不一样。
所谓蝶形单元即上面推导的最后一行,一个是\(P_e+\omega_n^kP_o\),一个是\(P_e-\omega_n^kP_o\)。
因此,递归如下:
- 要求\(P(\omega_n^k)\)和\(P(\omega_n^{k+\frac{n}{2}})\),\(0\le k<\frac{n}{2}\),只需求\(P_e(\omega_{\frac{n}{2}}^k)\)和\(P_o(\omega_{\frac{n}{2}}^k)\);
- 要求\(P_e(\omega_{\frac{n}{2}}^k)\)和\(P_e(\omega_{\frac{n}{2}}^{k+\frac{n}{4}})\),\(0\le k<\frac{n}{4}\),只需求\(P_e(\omega_{\frac{n}{4}}^k)\)和\(P_o(\omega_{\frac{n}{4}}^k)\),\(P_o(w_{\frac{n}{2}}^k)\)同理;
- ……
- 要求\(P_e(\omega_2^k)\)和\(P_e(\omega_2^{k+1})\),\(0\le k<1\),只需求\(P_e(w_1^k)\)和\(P_o(\omega_1^k)\);
- \(P_e(w_1^k)\)和\(P_o(\omega_1^k)\)即\(P_e(1)\)和\(P_o(1)\)直接返回。
代码
import cmath
import numpy as np
from numpy.polynomial import Polynomial
def FFT(P, inv=False):
"""
P: [p_0, p_1, ..., p_{n-1}]
"""
n = len(P)
assert n & (n - 1) == 0 # n是2的幂
if n == 1:
return P
Pe = P[::2] # 取偶次项 f0
Po = P[1::2] # 取奇次项 f1
ye = FFT(Pe, inv)
yo = FFT(Po, inv)
w_n = cmath.exp((-2j if inv else 2j) * cmath.pi / n) # n次单位根
w = 1 # 用于迭代
y = [0] * n
for k in range(n // 2): # 0 <= k < n/2
y[k] = ye[k] + w * yo[k] # f0 + w^k f1
y[k + n // 2] = ye[k] - w * yo[k] # f0 - w^k f1
w *= w_n
return y
def polynomial_multiply(a, b):
# 计算最终长度(大于等于n+m+1的最小的2的幂)
n = len(a)
m = len(b)
size = 1
while size < n + m - 1:
size <<= 1
# 补零
a += [0] * (size - n)
b += [0] * (size - m)
# FFT变换
a_fft = FFT(a)
b_fft = FFT(b)
# 点值相乘
c_fft = [a_fft[i] * b_fft[i] for i in range(size)]
# IFFT变换
c = FFT(c_fft, True)
c = [round(x.real / size) for x in c] # 注意IFFT后要除以size
return c[:n + m - 1]
def numpy_polynomial_multiply(a_coe, b_coe):
# 创建多项式对象
p1 = Polynomial(a_coe)
p2 = Polynomial(b_coe)
# 多项式乘法
product = p1 * p2
# 返回乘积的系数(从低次到高次)
return product.coef
def main():
# a = [1, 2, 3]
# b = [4, 5, 6, 7]
a = [27, 0, -33, 61, 95, -81, 65]
b = [-25, 63, -37, -10, 55, 67, -72, 1]
result = polynomial_multiply(a, b)
print("FFT:", result)
# [4, 13, 28, 34, 32, 21]
numpy_result = numpy_polynomial_multiply(a, b)
print("NumPy:", np.round(numpy_result).astype(int))
if __name__ == "__main__":
main()
注意事项
标准FFT实现使用的n次单位根是\(\omega_n=e^{\frac{-2\pi i}{n}}\),和前面的\(\omega_n\)差一个负号,所以上面的FFT函数实际上相当于标准IFFT,验证如下:
P = list(range(4)) Y1 = FFT(P) print(Y1) # [6, (-2-2j), -2, (-1.9999999999999998+2j)] Y2 = np.fft.ifft(P) * 4 print(Y2) # [ 6.+0.j -2.-2.j -2.+0.j -2.+2.j]
优化
上面的软件实现每次递归都会占用很多内存来存储\(P_e\)和\(P_o\),而我们更希望整个操作都在同一个数组里实现,不增加额外的存储空间。以8个元素为例,我们来模拟软件递归时进行的二分。
- (A)原系数:\((a_0\ a_1\ a_2\ a_3\ a_4\ a_5\ a_6\ a_7)\)。
- (B)二分一次:\((a_0\ a_2\ a_4\ a_6)(a_1\ a_3\ a_5\ a_7)\)。
- (C)二分两次:\((a_0\ a_4)(a_2\ a_6)(a_1\ a_5)(a_3\ a_7)\)。
- (D)二分三次:\((a_0)(a_4)(a_2)(a_6)(a_1)(a_5)(a_3)(a_7)\)。
按上述分法,把\(P_e\)存储在数组前面,\(P_o\)存储在后面。以第一次二分为例,A中的\(a_0,a_4\)是用B中的\(a_0,a_1\)蝶形计算的,它们在数组中的位置正好相同。因此如果已经得到了B,那么只需用B的\(a_0,a_1\)计算出A中的\(a_0,a_5\)然后覆写到数组同样的位置即可,其余同理。这样整个操作都在同一个数组中进行,无需额外的存储,而且数据的访问很规整。
按上述实现,输入的数组D应该按照下标为04261537的顺序给出,这正好是01234567的二进制位逆序,\(O(n)\)时间内进行逆序置换的算法参考这里。所以先对输入进行一次位逆序置换,然后再使用上述FFT,就能得到正确的输出;如果输入没有置换,那么FFT得到的输出就是位逆序的,需要进行置换才能得到正确结果。
NTT
参考以下文章:
原理
就是把n次单位根\(w_n\)替换为了原根\(g_n=g^{\frac{p-1}{n}}\)。
代码
M = 998244353 # 常用NTT模数,998244353 = 2 ^ 23 * 119 + 1
G = 3 # 原根
G_INV = 332748118
def NTT(P, inv=False):
"""
P: [p_0, p_1, ..., p_{n-1}]
"""
n = len(P)
assert n & (n - 1) == 0 # n is a power of 2
if n == 1:
return P
Pe = P[::2]
Po = P[1::2]
ye = NTT(Pe, inv)
yo = NTT(Po, inv)
w_n = pow(G_INV if inv else G, (M - 1) // n, M)
w = 1
y = [0] * n
for k in range(n // 2):
y[k] = (ye[k] + w * yo[k]) % M
y[k + n // 2] = (ye[k] - w * yo[k]) % M
w = w * w_n % M
return y
def polynomial_multiply_ntt(a, b):
# 计算最终长度
n = len(a)
m = len(b)
size = 1
while size < n + m - 1:
size <<= 1
# 补零
a += [0] * (size - n)
b += [0] * (size - m)
# NTT变换
a_ntt = NTT(a)
b_ntt = NTT(b)
# 点值相乘
c_ntt = [a_ntt[i] * b_ntt[i] % M for i in range(size)]
# INTT变换
c = NTT(c_ntt, True)
c = [x * pow(size, -1, M) % M for x in c] # 除以size
return c[:n + m - 1]
def main():
a = [1, 2, 3]
b = [4, 5, 6, 7]
result = polynomial_multiply_ntt(a, b)
print(result)
if __name__ == "__main__":
main()
以下是使用sage计算多项式乘法,用于验证上述代码。
from sage.all import *
M = 998244353
a_coe = [1, 2, 3]
b_coe = [4, 5, 6, 7]
R = PolynomialRing(GF(M), 'x')
a = R(a_coe)
b = R(b_coe)
product = a * b
print(list(product))
4步NTT
原理
回顾NTT的计算公式: \[ y_k=\sum_{n=0}^{N-1}a_n\omega^{kn} \] 现在将长度为\(N\)的一维数组\([a_0,a_1,\cdots,a_{N-1}]\)看成是\(R\)行\(C\)列的二维数组(\(N=R\cdot C\)),设输入输出矩阵的索引为: \[ \begin{align} n&=i\cdot C+j\\ k&=s\cdot R+r \end{align} \] 即\(i\)和\(r\)表示行,\(j\)和\(s\)表示列。代入NTT的公式可得: \[ \begin{align} y[sR+r]&=\sum_{j=0}^{C-1}\sum_{i=0}^{R-1}a[iC+j]\cdot\omega^{(iC+j)(sR+r)}\\ &=\sum_{j=0}^{C-1}\sum_{i=0}^{R-1}a[iC+j]\cdot\omega^{iCr}\cdot\omega^{jr}\cdot\omega^{sRj}\\ &=\sum_{j=0}^{C-1}\omega^{sRj}\cdot\left[\sum_{i=0}^{R-1}a[iC+j]\cdot\omega^{iCr}\right]\cdot\omega^{jr} \end{align} \] 写成二维数组的形式为: \[ y[r][s]=\sum_{j=0}^{C-1}\omega^{sRj}\cdot\left[\sum_{i=0}^{R-1}a[i][j]\cdot\omega^{iCr}\right]\cdot\omega^{jr} \]
进行如下说明:
二维数组\(arr[\cdot][\cdot]\)的第1个中括号表示行索引,第2个中括号表示列索引。
输入数组\(a[i][j]\)在内存中是按行优先存储的。
索引\(n\)是行优先,\(k\)是列优先。这样\(\omega^{kn}\)才会有一项是\(\omega^{RC}\),易知\(\omega^{RC}=1\),所以上面省略了这一项。
因此可以通过以下4步来计算NTT:
对二维数组的每一列做长度为\(R\)的NTT。旋转因子为\(\omega^C\),计算完成后写回原来的列。对应上面公式中的 \[ \sum_{i=0}^{R-1}a[i][j]\cdot\omega^{iCr} \]
对二维数组的每个元素乘以\(\omega^{ij}\)(\(i\)和\(j\)为该元素的行列号)。对应上面公式中的\(\omega^{jr}\)。
对二维数组的每一行做长度为\(C\)的NTT。旋转因子为\(\omega^R\),计算完成后写回原来的行。即上面公式中的 \[ \sum_{j=0}^{C-1}\omega^{sRj}\cdot[] \]
对二维数组进行转置。
为什么要转置?经过前3步的处理,\(y[r][s]\)的值确实已经存到了二维数组的第\(r\)行第\(s\)列。但是输出二维数组的索引是\(k=s\cdot R+r\),即\(y[0]=y[0\cdot R+0]=y[0][0]\)、\(y[1]=y[0\cdot R+1]=y[1][0]\)、\(y[2]=y[0\cdot R+2]=y[2][0]\)、……。因此正确的输出顺序应该是按列优先输出,即\(y[0][0],y[1][0],\cdots\),所以需要转置。
代码
def four_step_ntt(P, R, C):
assert len(P) == R * C
P_matrix = np.array(P).reshape((R, C)) # row-major
# Step 1: NTT on each column (length R)
# w1 = pow(G, (M - 1) // (R * C) * C, M)
for j in range(C):
col = [P_matrix[i][j] for i in range(R)]
col_ntt = NTT(col)
for i in range(R):
P_matrix[i][j] = col_ntt[i]
# Step 2: Multiply by twiddle factors
w = pow(G, (M - 1) // (R * C), M)
for i in range(R):
for j in range(C):
twiddle = pow(w, i * j, M)
P_matrix[i][j] = P_matrix[i][j] * twiddle % M
# Step 3: NTT on each row (length C)
# w2 = pow(G, (M - 1) // (R * C) * R, M)
for i in range(R):
row = P_matrix[i]
row_ntt = NTT(row.tolist())
P_matrix[i] = row_ntt
# Step 4 skipped (no transpose)
return P_matrix.flatten().tolist()
# Test input
P = [i for i in range(16)]
N = len(P)
R, C = 4, 4 # R * C = N
# Perform both NTT and 4-step NTT
ntt_result = NTT(P)
print(ntt_result)
# [120, 16886715, 790357655, 115058691,
# 692669736, 306777988, 403262520, 432660095,
# 998244345, 565584242, 594981817, 691466349,
# 305574601, 883185646, 207886682, 981357622]
four_step_result = four_step_ntt(P, R, C)
print(four_step_result)
# [120, 692669736, 998244345, 305574601,
# 16886715, 306777988, 565584242, 883185646,
# 790357655, 403262520, 594981817, 207886682,
# 115058691, 432660095, 691466349, 981357622]