蒙哥马利算法

概述

参考资料:Montgomery modular multiplication - Wikipedia

蒙哥马利算法包括模乘、约减和模幂。其主要思路是:通过选择合适的参数\(R\),将数转化为蒙哥马利形式,在该形式下除法可以用移位代替,从而简化了计算。进行模乘时,需要先将数转化为蒙哥马利形式,计算完毕后再转换回原形式,因此单次乘法开销较普通模乘大,但乘法数量多时,由于只需要在开始和结束时转换,因此效率很高。

标准蒙哥马利算法

转换为蒙哥马利形式

\(a\)\(b\)为乘数,\(N\)为模数。首先需要寻找一个参数\(R\),满足以下两个条件:①\(R>N\);②\(gcd(R,N)=1\),即\(R\)\(N\)互质。将一个数转化为其蒙哥马利形式,只需乘以\(R\)再取模\(N\),如\(a\)\(b\)的蒙哥马利形式为:\(aR\bmod N\)\(bR\bmod N\),记作\(\overline a\)\(\overline b\)

对于程序设计来说\(R\)通常选择为2的次幂,这样除以\(R\)可以用移位来简化计算。另外以上转换并不会损失任何信息,因为\(R\)\(N\)互质,转换前后的两个群是同构的。

\(N=17\)\(R=100\)为例:

原形式 转换 蒙哥马利形式
\(3\) \(3\times 100\bmod17=11\) \(11\)
\(7\) \(7\times 100\bmod17=3\) \(3\)
\(15\) \(15\times 100\bmod17=4\) \(4\)

蒙哥马利形式下的计算

蒙哥马利形式下加减法的结果是原形式下结果的蒙哥马利形式,因为: \[ \begin{align} \overline a+\overline b\equiv aR\bmod N+bR\bmod N\equiv (a+b)R\bmod N\equiv \overline{a+b}\pmod N\\ \overline a-\overline b\equiv aR\bmod N-bR\bmod N\equiv (a-b)R\bmod N\equiv \overline{a-b}\pmod N \end{align} \] 但乘法不一样,\(ab\)的蒙哥马利形式为\(\overline{ab}=abR\bmod N\),而将蒙哥马利形式下的\(a\)\(b\)相乘结果如下: \[ \overline a\times\overline b\equiv(aR\bmod N)(bR\bmod N)\equiv(abR)R\bmod N\ne\overline{ab}\pmod N \]

不难发现上式结果多了一个\(R\),因此要得到\(\overline{ab}\)还得再除以\(R\),即乘以\(R\)的逆元\(R^{-1}\),由于\(R\)\(N\)互质,逆元一定存在。

仍以\(N=17\)\(R=100\)为例,设\(a=7\)\(b=15\):因为\(100\times8\bmod17=1\),所以\(R\)的逆元\(R^{-1}=8\)。原形式的乘法\(ab\bmod N=7\times15\bmod17=3\),蒙哥马利形式的乘法\(\overline{ab}=\overline a\overline bR^{-1}\bmod N=3\times4\times8\bmod17=11\),恰好对应\(3\)的蒙哥马利形式。

蒙哥马利约减算法

上节中计算蒙哥马利乘法需要乘以\(R^{-1}\)再约减,虽然这样计算没有问题,但约减的计算量很大,因此有了这节的蒙哥马利约减算法,简称REDC。先看算法:

  • 输入\(T\)\(R\)\(k\)\(N\)\(N'\)。满足\(N'N\equiv-1\pmod R\)\(T\in[0,RN-1]\)\(R=2^k\)

  • 输出\(S=TR^{-1}\bmod N\)

  • 算法 \[ \begin{align} m&\leftarrow TN' \bmod R\\ t&\leftarrow(T+mN)>>k\\ S&=\begin{cases} t-N&t\ge N\\ t&t<N \end{cases} \end{align} \]

原理非常简单,我们希望乘以\(R^{-1}\)能通过右移\(k\)位计算来完成,但现在\(T\)的低位不是全0的,直接移位结果肯定不正确。因此想到给\(T\)加上一定数量(\(m\)个)的\(N\)(因为\(N\)是模数,加上\(mN\)不影响结果),使得\(T+mN\)\(R\)的倍数,这样低位全0就可以移位了。

那这个\(m\)是多少呢?假设\(T+mN\)\(R\)的倍数,即\(T+mN\equiv0\pmod R\),则\(m=-TN^{-1}\),也就是算法第一行。

由于\(m\in[0,R-1]\),有\(T+mN\in[0,RN-1+N(R-1)]<2RN\),所以\(t\in[0,2N-1]\),只需至多一次减法即可将\(t\)约减至\([0,N-1]\)

完整程序

# 约简 X / R
def reduction(T, N, R, k):
    N_ = -pow(N, -1, R)     # N' = -N^-1 (mod R)
    m = T * N_ % R          # m = TN' (mod R)
    y = (T + m * N) >> k    # X + mN 除以R相当于右移k位
    if y >= N:
        return y - N
    else:
        return y


# 蒙哥马利模乘:a * b % N
def montgomery_mult(A, B, N, k, R, is_mont_form=False):
    if not is_mont_form:
        # 转换
        A = A * R % N
        B = B * R % N

    # 求出 a' * b' 的蒙哥马利形式
    T = A * B                     # T = AB*RR (mod N)
    T1 = reduction(T, N, R, k)      # T1 = T / R = ABR (mod N)

    print(hex(T1))

    # 最后除以R得到结果
    T2 = reduction(T1, N, R, k)  # T2 = T1 / R = AB (mod N)
    return T2


def main():
    A = 12
    B = 50
    N = 29
    k = 6
    R = 2 ** k

    print("A:", hex(A))
    print("B:", hex(B))
    print(hex(A * B * pow(R, -1, N) % N))
    montgomery_mult(A, B, N, k, R, is_mont_form=True)

if __name__ == '__main__':
    main()

基于字的蒙哥马利算法

在很多硬件上不可能一次做完很大位宽的乘法,比如在64位的机器上做256位的模乘。因此就有了本节所谓的Word-Based Montgomery Modular Multiplication,简称Word-Based MMM,或Radix MMM,或维基百科的MultiPrecisionREDC。

约简算法

前面的思想是给\(T\)加上\(mN\),使得\(T+mN\)的低位有\(k\)个0,这样乘以\(R^{-1}\)就相当于右移\(k\)位。既然能一次性右移\(k\)位,在计算位宽不够的情况下能不能通过多次移位来达到同样的效果呢?这显然是可以的。

以具体的数据举例,设\(N\)是20位的模数,则可以取\(R=2^{20}\),假设运算器只支持4位的计算(\(k=4\)),设\(B=2^k=2^4=16\)\(B\)\(R\)的用法类似,都是为了通过右移来简化除法的),\(R=B^r\)\(r=5\)),\(T\)是要约简的数(\(0\le T<RN\))。另外设\(N'=-N^{-1}\pmod B\),这里模\(R\)也可以,因为计算的时候也会模\(B\)。我们称一个的长度为4位,可以理解为一个存储单元是4位,更大位宽的数用数组表示。那么模数\(N\)\(p=5\)个字(长度为5的数组),\(T\)\(r+p=10\)个字。这里\(r\)恰好等于\(p\)是因为\(R\)刚好是比\(N\)大一点的2的幂,比如\(R\)\(2^{24}\)\(r\)就不等于\(p\)了。

下面正式进入算法。算法一共分为\(r\)轮,每一轮(\(i\in[0,r)\))都是为了给\(T\)加上一定数量的\(N\)使得每一轮结束后\(T\)的后\(i+1\)个字全为0。抽象地说就是每一轮使得\(T+mN\equiv0\pmod{B^{i+1}}\)。这里非常重要,下面举例子说明:

  • 第0轮:要使得\(T+mN\equiv0\pmod{B}\),则\(m=-TN^{-1}\bmod B=TN'\bmod B\),这里\(T\)只需取最低的字即可,即\(m=T[0]N'\bmod B\)。然后令\(T\leftarrow T+mN\),至此现在的\(T\)就能被\(B\)整除了,即\(T\)的最低字全为0。
  • 第1轮:要使得\(T+mN\equiv0\pmod{B^2}\),显然这里加的\(mN\)应该是\(B\)的倍数(因为\(T\)目前已经是\(B\)的倍数了),所以我们明确写出来,即要使得\(T+mBN\equiv0\pmod{B^2}\)。根据同余的性质,等式和模数可以同时消去\(B\),即\(T[1:]+mN\equiv0\pmod B\)。所以\(m=-T[1:]N^{-1}\bmod B\),同样\(T\)只需取最低位,即\(m=T[1]N'\bmod B\)。然后令\(T\leftarrow T+mBN\),至此\(T\)就能被\(B^2\)整除了。
  • 第2轮:同理计算\(m=T[2]N'\bmod B\),然后令\(T\leftarrow T+mB^2N\),至此\(T\)就能被\(B^3\)整除了。
  • \(i\)轮:同理计算\(m=T[i]N'\bmod B\),然后令\(T\leftarrow T+mB^iN\),至此\(T\)就能被\(B^{i+1}\)整除了。
  • \(r-1\)轮,\(T\)已经能被\(B^5\)\(R\)整除了。

细节的部分直接上代码:

def reduction(T, N, p, R, r, B, k):
    """
    :param T: 待约简的数(r + p个字)
    :param N: 模数(p个字)
    :param p: N的字数
    :param R: R(r个字)
    :param r: R的字数
    :param B: B
    :param k: B的位数
    :return: T * R^{-1} % N
    """
    assert B == 2 ** k
    assert R == B ** r

    NN = N      # 备份一下
    N_ = pow(-N, -1, B)         # N' = -N^{-1} % B

    # 将T和N表示成字的数组
    T = [(T >> (i * k)) % B for i in range(r + p)]
    N = [(N >> (i * k)) % B for i in range(p)]

    T.append(0)     # 在最高位加一个字,防止后面溢出,所以T一共是p+1个字

    # 每一轮循环使得T的后i+1个字全为0
    for i in range(r):
        c = 0               # c是进位,后面会用到
        m = T[i] * N_ % B   # m = T[i] * N' % B

        # 下面两个循环都是在计算:T = T + m * B^i * N
        # 因为只能计算k位的乘法,所以只能依次用m去乘N[j]
        # 至于乘B^i,直接跳过低位的字,即使用T[i+j]
        for j in range(p):                  # 因为是与N相乘,N有p个字就循环p次
            t = T[i + j] + m * N[j] + c     # 计算中间结果
            T[i + j] = t % B                # 低位存储在当前字
            c = t // B                      # 高位进位到下一个字
        # 这个循环是为了处理多余的进位,只需处理到 r+p (闭区间)
        for j in range(p, r + p - i + 1):   # 即 i+j 从 p+i 循环到 r+p
            t = T[i + j] + c
            T[i + j] = t % B
            c = t // B

    # 除以R,即取前p+1个字
    S = 0
    for i in range(p + 1):
        S = S + T[i + r] * (B ** i)

    if S >= NN:
        return S - NN
    else:
        return S

def main():
    p = 5
    N = 0xf0001
    k = 4
    B = 2 ** k
    r = 5
    R = B ** r
    T = 0x123456789

    print(hex(T * pow(R, -1, N) % N))
    print(hex(reduction(T, N, p, R, r, B, k)))


if __name__ == '__main__':
    main()

完整算法

与上面几乎同理,只不过\(T\)没有提前算好,而是一边迭代一边计算。

def to_words(val, k, num_words):
    """将大整数转换为字数组 (Little Endian)"""
    words = []
    for _ in range(num_words):
        words.append(val % (2 ** k))
        val //= 2 ** k
    return words


def from_words(words, k):
    """将字数组转回大整数"""
    val = 0
    for i, w in enumerate(words):
        val += w * (2 ** (k * i))
    return val


def word_based_MMM(X, Y, N, k, n):
    """
    :param X: [0, N)
    :param Y: [0, N)
    :param N: 模数
    :param k: 字长
    :param n: R = 2^n, R末尾0的个数,一般也是模数的位数
    """
    # 字的个数,s = ceil(n / k)
    s = (n + k - 1) // k

    # M' = -M^{-1} mod 2^k,模 2^n 也可
    M_ = pow(-N, -1, 2 ** k)

    # 将输入转为字数组
    X_arr = to_words(X, k, s)
    Y_arr = to_words(Y, k, s)
    M_arr = to_words(N, k, s)

    # 初始化 Z 数组,由于计算过程中会有进位,多加了1位
    Z_arr = [0] * (s + 1)

    # 外循环,遍历 X 的每一个字
    for i in range(s):
        # 计算 m
        tmp_sum = Z_arr[0] + X_arr[i] * Y_arr[0]
        m = (tmp_sum * M_) % (2 ** k)

        carry = 0  # 将Ca和Cb合并了

        # 内循环,将 j=0 也合并进去了
        for j in range(s):
            # 累加 X[i] * Y[j] 和 m * M[j]
            sum_val = Z_arr[j] + X_arr[i] * Y_arr[j] + m * M_arr[j] + carry

            # 右移 k
            if j > 0:
                Z_arr[j - 1] = sum_val % (2 ** k)

            # 高 k 位作为进位传给下一次 j 循环
            carry = sum_val // (2 ** k)

        # 处理剩余的进位
        sum_val = Z_arr[s] + carry
        Z_arr[s - 1] = sum_val % (2 ** k)
        Z_arr[s] = sum_val // (2 ** k)

    # 将 Z 数组转回大整数
    Z = from_words(Z_arr, k)

    if Z >= N:
        Z -= N
    return Z

if __name__ == "__main__":
    n = 20          # R = 2^n
    k = 4           # 字长
    N = 0xff001     # 模数
    R = 2 ** n
    X = 0x12345
    Y = 0x56789

    print((X * Y * pow(R, -1, N)) % N)
    print(word_based_MMM(X, Y, N, k, n))

蒙哥马利算法
https://shuusui.site/blog/2024/06/24/algo-montgomery/
作者
Shuusui
发布于
2024年6月24日
更新于
2024年6月24日
许可协议