FFT和NTT

FFT

参考以下文章:

本质

设多项式\(P(x)=a_{n-1}x^{n-1}+\cdots+a_1x+a_0\)\(a_i\)即是多项式的系数表示。设\(y_i\)是多项式的点值表示\(\omega_n\)表示n次单位根(\(e^{\frac{2\pi i}{n}}\))。FFT就是将\(\omega_n\)\(0\sim n-1\)次方依次代入多项式,即如下矩阵乘法: \[ \begin{bmatrix} 1&1&1&\dots&1\\ 1&(\omega_n^1)^1&(\omega_n^1)^2&\dots&(\omega_n^1)^{n-1}\\ 1&(\omega_n^2)^1&(\omega_n^2)^2&\dots&(\omega_n^2)^{n-1}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 1&(\omega_n^{n-1})^1&(\omega_n^{n-1})^2&\dots&(\omega_n^{n-1})^{n-1} \end{bmatrix} \begin{bmatrix} a_0\\ a_1\\ a_2\\ \vdots\\ a_{n-1} \end{bmatrix} = \begin{bmatrix} y_0\\ y_1\\ y_2\\ \vdots\\ y_{n-1} \end{bmatrix} \] 或写成: \[ y_k=\sum_{i=0}^{n-1}a_i\omega_n^{ki} \] IFFT即点值表示乘以上面方阵的逆矩阵,求出系数表示,其中逆矩阵如下。(注意逆矩阵前面乘了\(\frac{1}{n}\),因此IFFT做完后要除以\(n\))。 \[ \frac{1}{n} \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & (\overline{\omega_n^1})^1 & (\overline{\omega_n^1})^2 & \cdots & (\overline{\omega_n^1})^{n-1} \\ 1 & (\overline{\omega_n^2})^1 & (\overline{\omega_n^2})^2 & \cdots & (\overline{\omega_n^2})^{n-1} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & (\overline{\omega_n^{n-1}})^1 & (\overline{\omega_n^{n-1}})^2 & \cdots & (\overline{\omega_n^{n-1}})^{n-1} \end{bmatrix} \]

算法

显然直接计算矩阵乘法复杂度为\(O(n^2)\),下面给出\(O(n\log n)\)的算法。

\(P(x)\)按次数的奇偶分成两部分: \[ \begin{align} P(x)=&(a_0+a_2x^2+\cdots+a_{n-2}x^{n-2})\\ +&(a_1x+a_3x^3+\cdots+a_{n-1}x^{n-1}) \end{align} \]\[ \begin{align} P_e(x)&=a_0+a_2x+\cdots+a_{n-2}x^{\frac{n-2}{2}}\\ P_o(x)&=a_1+a_3x+\cdots+a_{n-1}x^{\frac{n-2}{2}} \end{align} \]\[ P(x)=P_e(x^2)+xP_o(x^2) \] 代入\(x=\omega_n^k\)\(0\le k<\frac{n}{2}\))可得 \[ \begin{align} P(\omega_n^k)&=P_e(\omega_n^{2k})+\omega_n^kP_o(\omega_n^{2k})\\ &=P_e(\omega_{\frac{n}{2}}^k)+\omega_n^kP_o(\omega_{\frac{n}{2}}^k) \end{align} \] 再代入\(x=\omega_n^{k+\frac{n}{2}}\)\(0\le k<\frac{n}{2}\))可得 \[ \begin{align} P(\omega_n^{k+\frac{n}{2}})&=P_e(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}P_o(\omega_n^{2k+n})\\ &=P_e(\omega_n^{2k})+\omega_n^{k+\frac{n}{2}}P_o(\omega_n^{2k})\\ &=P_e(\omega_{\frac{n}{2}}^k)-\omega_n^kP_o(\omega_{\frac{n}{2}}^k) \end{align} \] > 注意旋转因子(Twiddle Factor)\(\omega_n^k\)\(0\le k<\frac{n}{2}\)),和\(P_e,P_o\)中的自变量下标不一样。

所谓蝶形单元即上面推导的最后一行,一个是\(P_e+\omega_n^kP_o\),一个是\(P_e-\omega_n^kP_o\)

因此,递归如下:

  • 要求\(P(\omega_n^k)\)\(P(\omega_n^{k+\frac{n}{2}})\)\(0\le k<\frac{n}{2}\),只需求\(P_e(\omega_{\frac{n}{2}}^k)\)\(P_o(\omega_{\frac{n}{2}}^k)\)
  • 要求\(P_e(\omega_{\frac{n}{2}}^k)\)\(P_e(\omega_{\frac{n}{2}}^{k+\frac{n}{4}})\)\(0\le k<\frac{n}{4}\),只需求\(P_e(\omega_{\frac{n}{4}}^k)\)\(P_o(\omega_{\frac{n}{4}}^k)\)\(P_o(w_{\frac{n}{2}}^k)\)同理;
  • ……
  • 要求\(P_e(\omega_2^k)\)\(P_e(\omega_2^{k+1})\)\(0\le k<1\),只需求\(P_e(w_1^k)\)\(P_o(\omega_1^k)\)
  • \(P_e(w_1^k)\)\(P_o(\omega_1^k)\)\(P_e(1)\)\(P_o(1)\)直接返回。

代码

import cmath
import numpy as np
from numpy.polynomial import Polynomial

def FFT(P, inv=False):
    """
    P: [p_0, p_1, ..., p_{n-1}]
    """
    n = len(P)
    assert n & (n - 1) == 0     # n是2的幂
    if n == 1:
        return P

    Pe = P[::2]                 # 取偶次项 f0
    Po = P[1::2]                # 取奇次项 f1
    ye = FFT(Pe, inv)
    yo = FFT(Po, inv)

    w_n = cmath.exp((-2j if inv else 2j) * cmath.pi / n)    # n次单位根
    w = 1                                   # 用于迭代
    y = [0] * n
    for k in range(n // 2):                 # 0 <= k < n/2
        y[k] = ye[k] + w * yo[k]            # f0 + w^k f1
        y[k + n // 2] = ye[k] - w * yo[k]   # f0 - w^k f1
        w *= w_n
    return y

def polynomial_multiply(a, b):
    # 计算最终长度(大于等于n+m+1的最小的2的幂)
    n = len(a)
    m = len(b)
    size = 1
    while size < n + m - 1:
        size <<= 1

    # 补零
    a += [0] * (size - n)
    b += [0] * (size - m)

    # FFT变换
    a_fft = FFT(a)
    b_fft = FFT(b)

    # 点值相乘
    c_fft = [a_fft[i] * b_fft[i] for i in range(size)]

    # IFFT变换
    c = FFT(c_fft, True)
    c = [round(x.real / size) for x in c]   # 注意IFFT后要除以size

    return c[:n + m - 1]

def numpy_polynomial_multiply(a_coe, b_coe):
    # 创建多项式对象
    p1 = Polynomial(a_coe)
    p2 = Polynomial(b_coe)

    # 多项式乘法
    product = p1 * p2

    # 返回乘积的系数(从低次到高次)
    return product.coef

def main():
    # a = [1, 2, 3]
    # b = [4, 5, 6, 7]
    a = [27, 0, -33, 61, 95, -81, 65]
    b = [-25, 63, -37, -10, 55, 67, -72, 1]

    result = polynomial_multiply(a, b)

    print("FFT:", result)
    # [4, 13, 28, 34, 32, 21]

    numpy_result = numpy_polynomial_multiply(a, b)
    print("NumPy:", np.round(numpy_result).astype(int))

if __name__ == "__main__":
    main()

注意事项

标准FFT实现使用的n次单位根是\(\omega_n=e^{\frac{-2\pi i}{n}}\),和前面的\(\omega_n\)差一个负号,所以上面的FFT函数实际上相当于标准IFFT,验证如下:

P = list(range(4))

Y1 = FFT(P)
print(Y1)
# [6, (-2-2j), -2, (-1.9999999999999998+2j)]

Y2 = np.fft.ifft(P) * 4
print(Y2)
# [ 6.+0.j -2.-2.j -2.+0.j -2.+2.j]

优化

上面的软件实现每次递归都会占用很多内存来存储\(P_e\)\(P_o\),而我们更希望整个操作都在同一个数组里实现,不增加额外的存储空间。以8个元素为例,我们来模拟软件递归时进行的二分。

  • (A)原系数:\((a_0\ a_1\ a_2\ a_3\ a_4\ a_5\ a_6\ a_7)\)
  • (B)二分一次:\((a_0\ a_2\ a_4\ a_6)(a_1\ a_3\ a_5\ a_7)\)
  • (C)二分两次:\((a_0\ a_4)(a_2\ a_6)(a_1\ a_5)(a_3\ a_7)\)
  • (D)二分三次:\((a_0)(a_4)(a_2)(a_6)(a_1)(a_5)(a_3)(a_7)\)

按上述分法,把\(P_e\)存储在数组前面,\(P_o\)存储在后面。以第一次二分为例,A中的\(a_0,a_4\)是用B中的\(a_0,a_1\)蝶形计算的,它们在数组中的位置正好相同。因此如果已经得到了B,那么只需用B的\(a_0,a_1\)计算出A中的\(a_0,a_5\)然后覆写到数组同样的位置即可,其余同理。这样整个操作都在同一个数组中进行,无需额外的存储,而且数据的访问很规整。

按上述实现,输入的数组D应该按照下标为04261537的顺序给出,这正好是01234567的二进制位逆序\(O(n)\)时间内进行逆序置换的算法参考这里。所以先对输入进行一次位逆序置换,然后再使用上述FFT,就能得到正确的输出;如果输入没有置换,那么FFT得到的输出就是位逆序的,需要进行置换才能得到正确结果。

NTT

参考以下文章:

原理

就是把n次单位根\(w_n\)替换为了原根\(g_n=g^{\frac{p-1}{n}}\)

代码

M = 998244353       # 常用NTT模数,998244353 = 2 ^ 23 * 119 + 1
G = 3               # 原根
G_INV = 332748118


def NTT(P, inv=False):
    """
    P: [p_0, p_1, ..., p_{n-1}]
    """
    n = len(P)
    assert n & (n - 1) == 0  # n is a power of 2
    if n == 1:
        return P

    Pe = P[::2]
    Po = P[1::2]
    ye = NTT(Pe, inv)
    yo = NTT(Po, inv)

    w_n = pow(G_INV if inv else G, (M - 1) // n, M)
    w = 1
    y = [0] * n
    for k in range(n // 2):
        y[k] = (ye[k] + w * yo[k]) % M
        y[k + n // 2] = (ye[k] - w * yo[k]) % M
        w = w * w_n % M
    return y


def polynomial_multiply_ntt(a, b):
    # 计算最终长度
    n = len(a)
    m = len(b)
    size = 1
    while size < n + m - 1:
        size <<= 1

    # 补零
    a += [0] * (size - n)
    b += [0] * (size - m)

    # NTT变换
    a_ntt = NTT(a)
    b_ntt = NTT(b)

    # 点值相乘
    c_ntt = [a_ntt[i] * b_ntt[i] % M for i in range(size)]

    # INTT变换
    c = NTT(c_ntt, True)
    c = [x * pow(size, -1, M) % M for x in c]   # 除以size

    return c[:n + m - 1]


def main():
    a = [1, 2, 3]
    b = [4, 5, 6, 7]
    result = polynomial_multiply_ntt(a, b)
    print(result)


if __name__ == "__main__":
    main()

以下是使用sage计算多项式乘法,用于验证上述代码。

from sage.all import *

M = 998244353
a_coe = [1, 2, 3]
b_coe = [4, 5, 6, 7]

R = PolynomialRing(GF(M), 'x')
a = R(a_coe)
b = R(b_coe)
product = a * b

print(list(product))

4步NTT

原理

回顾NTT的计算公式: \[ y_k=\sum_{n=0}^{N-1}a_n\omega^{kn} \] 现在将长度为\(N\)的一维数组\([a_0,a_1,\cdots,a_{N-1}]\)看成是\(R\)\(C\)列的二维数组(\(N=R\cdot C\)),设输入输出矩阵的索引为: \[ \begin{align} n&=i\cdot C+j\\ k&=s\cdot R+r \end{align} \]\(i\)\(r\)表示行,\(j\)\(s\)表示列。代入NTT的公式可得: \[ \begin{align} y[sR+r]&=\sum_{j=0}^{C-1}\sum_{i=0}^{R-1}a[iC+j]\cdot\omega^{(iC+j)(sR+r)}\\ &=\sum_{j=0}^{C-1}\sum_{i=0}^{R-1}a[iC+j]\cdot\omega^{iCr}\cdot\omega^{jr}\cdot\omega^{sRj}\\ &=\sum_{j=0}^{C-1}\omega^{sRj}\cdot\left[\sum_{i=0}^{R-1}a[iC+j]\cdot\omega^{iCr}\right]\cdot\omega^{jr} \end{align} \] 写成二维数组的形式为: \[ y[r][s]=\sum_{j=0}^{C-1}\omega^{sRj}\cdot\left[\sum_{i=0}^{R-1}a[i][j]\cdot\omega^{iCr}\right]\cdot\omega^{jr} \]

进行如下说明:

  • 二维数组\(arr[\cdot][\cdot]\)的第1个中括号表示行索引,第2个中括号表示列索引。

  • 输入数组\(a[i][j]\)在内存中是按行优先存储的。

  • 索引\(n\)是行优先,\(k\)是列优先。这样\(\omega^{kn}\)才会有一项是\(\omega^{RC}\),易知\(\omega^{RC}=1\),所以上面省略了这一项。

因此可以通过以下4步来计算NTT:

  • 对二维数组的每一列做长度为\(R\)的NTT。旋转因子为\(\omega^C\),计算完成后写回原来的列。对应上面公式中的 \[ \sum_{i=0}^{R-1}a[i][j]\cdot\omega^{iCr} \]

  • 对二维数组的每个元素乘以\(\omega^{ij}\)\(i\)\(j\)为该元素的行列号)。对应上面公式中的\(\omega^{jr}\)

  • 对二维数组的每一行做长度为\(C\)的NTT。旋转因子为\(\omega^R\),计算完成后写回原来的行。即上面公式中的 \[ \sum_{j=0}^{C-1}\omega^{sRj}\cdot[] \]

  • 对二维数组进行转置

为什么要转置?经过前3步的处理,\(y[r][s]\)的值确实已经存到了二维数组的第\(r\)行第\(s\)列。但是输出二维数组的索引是\(k=s\cdot R+r\),即\(y[0]=y[0\cdot R+0]=y[0][0]\)\(y[1]=y[0\cdot R+1]=y[1][0]\)\(y[2]=y[0\cdot R+2]=y[2][0]\)、……。因此正确的输出顺序应该是按列优先输出,即\(y[0][0],y[1][0],\cdots\),所以需要转置。

代码

def four_step_ntt(P, R, C):
    assert len(P) == R * C
    P_matrix = np.array(P).reshape((R, C))  # row-major

    # Step 1: NTT on each column (length R)
    # w1 = pow(G, (M - 1) // (R * C) * C, M)
    for j in range(C):
        col = [P_matrix[i][j] for i in range(R)]
        col_ntt = NTT(col)
        for i in range(R):
            P_matrix[i][j] = col_ntt[i]

    # Step 2: Multiply by twiddle factors
    w = pow(G, (M - 1) // (R * C), M)
    for i in range(R):
        for j in range(C):
            twiddle = pow(w, i * j, M)
            P_matrix[i][j] = P_matrix[i][j] * twiddle % M

    # Step 3: NTT on each row (length C)
    # w2 = pow(G, (M - 1) // (R * C) * R, M)
    for i in range(R):
        row = P_matrix[i]
        row_ntt = NTT(row.tolist())
        P_matrix[i] = row_ntt

    # Step 4 skipped (no transpose)
    return P_matrix.flatten().tolist()


# Test input
P = [i for i in range(16)]
N = len(P)
R, C = 4, 4  # R * C = N

# Perform both NTT and 4-step NTT
ntt_result = NTT(P)
print(ntt_result)
# [120, 16886715, 790357655, 115058691,
#  692669736, 306777988, 403262520, 432660095,
#  998244345, 565584242, 594981817, 691466349,
#  305574601, 883185646, 207886682, 981357622]

four_step_result = four_step_ntt(P, R, C)
print(four_step_result)
# [120, 692669736, 998244345, 305574601,
#  16886715, 306777988, 565584242, 883185646,
#  790357655, 403262520, 594981817, 207886682,
#  115058691, 432660095, 691466349, 981357622]

FFT和NTT
https://shuusui.site/blog/2025/05/10/algo-fft-ntt/
作者
Shuusui
发布于
2025年5月10日
更新于
2025年8月3日
许可协议