FFT和NTT

FFT

原理

参考以下文章:

\(\omega_n\)表示n次单位根。FFT其实就是计算如下矩阵乘法: \[ \begin{bmatrix} 1&1&1&\dots&1\\ 1&(\omega_n^1)^1&(\omega_n^1)^2&\dots&(\omega_n^1)^{n-1}\\ 1&(\omega_n^2)^1&(\omega_n^2)^2&\dots&(\omega_n^2)^{n-1}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 1&(\omega_n^{n-1})^1&(\omega_n^{n-1})^2&\dots&(\omega_n^{n-1})^{n-1} \end{bmatrix} \begin{bmatrix} a_0\\ a_1\\ a_2\\ \vdots\\ a_{n-1} \end{bmatrix} = \begin{bmatrix} y_0\\ y_1\\ y_2\\ \vdots\\ y_{n-1} \end{bmatrix} \] 其中\(a_i\)是多项式的系数表示,\(y_i\)是多项式的点值表示。IFFT即点值表示乘以上面方阵的逆矩阵,求出系数表示,其中逆矩阵如下。(注意逆矩阵前面乘了\(\frac{1}{n}\),因此IFFT做完后要除以\(n\))。 \[ \frac{1}{n} \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & (\overline{\omega_n^1})^1 & (\overline{\omega_n^1})^2 & \cdots & (\overline{\omega_n^1})^{n-1} \\ 1 & (\overline{\omega_n^2})^1 & (\overline{\omega_n^2})^2 & \cdots & (\overline{\omega_n^2})^{n-1} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & (\overline{\omega_n^{n-1}})^1 & (\overline{\omega_n^{n-1}})^2 & \cdots & (\overline{\omega_n^{n-1}})^{n-1} \end{bmatrix} \]

代码

import cmath
import numpy as np
from numpy.polynomial import Polynomial


def FFT(P, inv=False):
    """
    P: [p_0, p_1, ..., p_{n-1}]
    """
    n = len(P)
    assert n & (n - 1) == 0     # n is a power of 2
    if n == 1:
        return P

    Pe = P[::2]
    Po = P[1::2]
    ye = FFT(Pe, inv)
    yo = FFT(Po, inv)

    w_n = cmath.exp((-2j if inv else 2j) * cmath.pi / n)
    w = 1
    y = [0] * n
    for k in range(n // 2):
        y[k] = ye[k] + w * yo[k]
        y[k + n // 2] = ye[k] - w * yo[k]
        w *= w_n
    return y


def polynomial_multiply(a, b):
    # 计算最终长度
    n = len(a)
    m = len(b)
    size = 1
    while size < n + m - 1:
        size <<= 1

    # 补零
    a += [0] * (size - n)
    b += [0] * (size - m)

    # FFT变换
    a_fft = FFT(a)
    b_fft = FFT(b)

    # 点值相乘
    c_fft = [a_fft[i] * b_fft[i] for i in range(size)]

    # IFFT变换
    c = FFT(c_fft, True)
    c = [round(x.real / size) for x in c]   # 注意IFFT后要除以size

    return c[:n + m - 1]


def numpy_polynomial_multiply(a_coe, b_coe):
    p1 = Polynomial(a_coe)
    p2 = Polynomial(b_coe)
    product = p1 * p2
    return product.coef


def main():
    a = [1, 2, 3]
    b = [4, 5, 6, 7]

    result = polynomial_multiply(a, b)
    print("FFT:", result)
    # [4, 13, 28, 34, 32, 21]

    numpy_result = numpy_polynomial_multiply(a, b)
    print("NumPy:", np.round(numpy_result).astype(int))


if __name__ == "__main__":
    main()

NTT

原理

参考以下文章:

就是把n次单位根替换为了原根。

代码

M = 998244353       # 常用NTT模数,998244353 = 2 ^ 23 * 119 + 1
G = 3               # 原根
G_INV = 332748118


def NTT(P, inv=False):
    """
    P: [p_0, p_1, ..., p_{n-1}]
    """
    n = len(P)
    assert n & (n - 1) == 0  # n is a power of 2
    if n == 1:
        return P

    Pe = P[::2]
    Po = P[1::2]
    ye = NTT(Pe, inv)
    yo = NTT(Po, inv)

    w_n = pow(G_INV if inv else G, (M - 1) // n, M)
    w = 1
    y = [0] * n
    for k in range(n // 2):
        y[k] = (ye[k] + w * yo[k]) % M
        y[k + n // 2] = (ye[k] - w * yo[k]) % M
        w = w * w_n % M
    return y


def polynomial_multiply_ntt(a, b):
    # 计算最终长度
    n = len(a)
    m = len(b)
    size = 1
    while size < n + m - 1:
        size <<= 1

    # 补零
    a += [0] * (size - n)
    b += [0] * (size - m)

    # NTT变换
    a_ntt = NTT(a)
    b_ntt = NTT(b)

    # 点值相乘
    c_ntt = [a_ntt[i] * b_ntt[i] % M for i in range(size)]

    # INTT变换
    c = NTT(c_ntt, True)
    c = [x * pow(size, -1, M) % M for x in c]   # 除以size

    return c[:n + m - 1]


def main():
    a = [1, 2, 3]
    b = [4, 5, 6, 7]
    result = polynomial_multiply_ntt(a, b)
    print(result)


if __name__ == "__main__":
    main()

以下是使用sage计算多项式乘法,用于验证上述代码。

from sage.all import *

M = 998244353
a_coe = [1, 2, 3]
b_coe = [4, 5, 6, 7]

R = PolynomialRing(GF(M), 'x')
a = R(a_coe)
b = R(b_coe)
product = a * b

print(list(product))

FFT和NTT
https://shuusui.site/blog/2025/05/10/algo-fft-ntt/
作者
Shuusui
发布于
2025年5月10日
更新于
2025年5月10日
许可协议