蒙哥马利算法
概述
蒙哥马利算法包括模乘、约减和模幂。其主要思路是:通过选择合适的参数\(R\),将数转化为蒙哥马利形式,在该形式下除法可以用移位代替,从而简化了计算。进行模乘时,需要先将数转化为蒙哥马利形式,计算完毕后再转换回原形式,因此单次乘法开销较普通模乘大,但乘法数量多时,由于只需要在开始和结束时转换,因此效率很高。
标准蒙哥马利算法
转换为蒙哥马利形式
设\(a\)、\(b\)为乘数,\(N\)为模数。首先需要寻找一个参数\(R\),满足以下两个条件:①\(R>N\);②\(gcd(R,N)=1\),即\(R\)与\(N\)互质。将一个数转化为其蒙哥马利形式,只需乘以\(R\)再取模\(N\),如\(a\)、\(b\)的蒙哥马利形式为:\(aR\bmod N\)和\(bR\bmod N\),记作\(\overline a\)和\(\overline b\)。
对于程序设计来说\(R\)通常选择为2的次幂,这样除以\(R\)可以用移位来简化计算。另外以上转换并不会损失任何信息,因为\(R\)与\(N\)互质,转换前后的两个群是同构的。
以\(N=17\),\(R=100\)为例:
| 原形式 | 转换 | 蒙哥马利形式 |
|---|---|---|
| \(3\) | \(3\times 100\bmod17=11\) | \(11\) |
| \(7\) | \(7\times 100\bmod17=3\) | \(3\) |
| \(15\) | \(15\times 100\bmod17=4\) | \(4\) |
蒙哥马利形式下的计算
蒙哥马利形式下加减法的结果是原形式下结果的蒙哥马利形式,因为: \[ \begin{align} \overline a+\overline b\equiv aR\bmod N+bR\bmod N\equiv (a+b)R\bmod N\equiv \overline{a+b}\pmod N\\ \overline a-\overline b\equiv aR\bmod N-bR\bmod N\equiv (a-b)R\bmod N\equiv \overline{a-b}\pmod N \end{align} \] 但乘法不一样,\(ab\)的蒙哥马利形式为\(\overline{ab}=abR\bmod N\),而将蒙哥马利形式下的\(a\)和\(b\)相乘结果如下: \[ \overline a\times\overline b\equiv(aR\bmod N)(bR\bmod N)\equiv(abR)R\bmod N\ne\overline{ab}\pmod N \]
不难发现上式结果多了一个\(R\),因此要得到\(\overline{ab}\)还得再除以\(R\),即乘以\(R\)的逆元\(R^{-1}\),由于\(R\)与\(N\)互质,逆元一定存在。
仍以\(N=17\),\(R=100\)为例,设\(a=7\),\(b=15\):因为\(100\times8\bmod17=1\),所以\(R\)的逆元\(R^{-1}=8\)。原形式的乘法\(ab\bmod N=7\times15\bmod17=3\),蒙哥马利形式的乘法\(\overline{ab}=\overline a\overline bR^{-1}\bmod N=3\times4\times8\bmod17=11\),恰好对应\(3\)的蒙哥马利形式。
蒙哥马利约减算法
上节中计算蒙哥马利乘法需要乘以\(R^{-1}\)再约减,虽然这样计算没有问题,但约减的计算量很大,因此有了这节的蒙哥马利约减算法,简称REDC。先看算法:
输入:\(T\),\(R\),\(k\),\(N\),\(N'\)。满足\(N'N\equiv-1\pmod R\),\(T\in[0,RN-1]\),\(R=2^k\)。
输出:\(S=TR^{-1}\bmod N\)。
算法 \[ \begin{align} m&\leftarrow TN' \bmod R\\ t&\leftarrow(T+mN)>>k\\ S&=\begin{cases} t-N&t\ge N\\ t&t<N \end{cases} \end{align} \]
原理非常简单,我们希望乘以\(R^{-1}\)能通过右移\(k\)位计算来完成,但现在\(T\)的低位不是全0的,直接移位结果肯定不正确。因此想到给\(T\)加上一定数量(\(m\)个)的\(N\)(因为\(N\)是模数,加上\(mN\)不影响结果),使得\(T+mN\)是\(R\)的倍数,这样低位全0就可以移位了。
那这个\(m\)是多少呢?假设\(T+mN\)是\(R\)的倍数,即\(T+mN\equiv0\pmod R\),则\(m=-TN^{-1}\),也就是算法第一行。
由于\(m\in[0,R-1]\),有\(T+mN\in[0,RN-1+N(R-1)]<2RN\),所以\(t\in[0,2N-1]\),只需至多一次减法即可将\(t\)约减至\([0,N-1]\)。
完整程序
# 约简 X / R
def reduction(T, N, R, k):
N_ = -pow(N, -1, R) # N' = -N^-1 (mod R)
m = T * N_ % R # m = TN' (mod R)
y = (T + m * N) >> k # X + mN 除以R相当于右移k位
if y >= N:
return y - N
else:
return y
# 蒙哥马利模乘:a * b % N
def montgomery_mult(A, B, N, k, R, is_mont_form=False):
if not is_mont_form:
# 转换
A = A * R % N
B = B * R % N
# 求出 a' * b' 的蒙哥马利形式
T = A * B # T = AB*RR (mod N)
T1 = reduction(T, N, R, k) # T1 = T / R = ABR (mod N)
print(hex(T1))
# 最后除以R得到结果
T2 = reduction(T1, N, R, k) # T2 = T1 / R = AB (mod N)
return T2
def main():
A = 12
B = 50
N = 29
k = 6
R = 2 ** k
print("A:", hex(A))
print("B:", hex(B))
print(hex(A * B * pow(R, -1, N) % N))
montgomery_mult(A, B, N, k, R, is_mont_form=True)
if __name__ == '__main__':
main()基于字的蒙哥马利算法
在很多硬件上不可能一次做完很大位宽的乘法,比如在64位的机器上做256位的模乘。因此就有了本节所谓的Word-Based Montgomery Modular Multiplication,简称Word-Based MMM,或Radix MMM,或维基百科的MultiPrecisionREDC。
约简算法
前面的思想是给\(T\)加上\(mN\),使得\(T+mN\)的低位有\(k\)个0,这样乘以\(R^{-1}\)就相当于右移\(k\)位。既然能一次性右移\(k\)位,在计算位宽不够的情况下能不能通过多次移位来达到同样的效果呢?这显然是可以的。
以具体的数据举例,设\(N\)是20位的模数,则可以取\(R=2^{20}\),假设运算器只支持4位的计算(\(k=4\)),设\(B=2^k=2^4=16\)(\(B\)和\(R\)的用法类似,都是为了通过右移来简化除法的),\(R=B^r\)(\(r=5\)),\(T\)是要约简的数(\(0\le T<RN\))。另外设\(N'=-N^{-1}\pmod B\),这里模\(R\)也可以,因为计算的时候也会模\(B\)。我们称一个字的长度为4位,可以理解为一个存储单元是4位,更大位宽的数用数组表示。那么模数\(N\)占\(p=5\)个字(长度为5的数组),\(T\)占\(r+p=10\)个字。这里\(r\)恰好等于\(p\)是因为\(R\)刚好是比\(N\)大一点的2的幂,比如\(R\)取\(2^{24}\),\(r\)就不等于\(p\)了。
下面正式进入算法。算法一共分为\(r\)轮,每一轮(\(i\in[0,r)\))都是为了给\(T\)加上一定数量的\(N\),使得每一轮结束后\(T\)的后\(i+1\)个字全为0。抽象地说就是每一轮使得\(T+mN\equiv0\pmod{B^{i+1}}\)。这里非常重要,下面举例子说明:
- 第0轮:要使得\(T+mN\equiv0\pmod{B}\),则\(m=-TN^{-1}\bmod B=TN'\bmod B\),这里\(T\)只需取最低的字即可,即\(m=T[0]N'\bmod B\)。然后令\(T\leftarrow T+mN\),至此现在的\(T\)就能被\(B\)整除了,即\(T\)的最低字全为0。
- 第1轮:要使得\(T+mN\equiv0\pmod{B^2}\),显然这里加的\(mN\)应该是\(B\)的倍数(因为\(T\)目前已经是\(B\)的倍数了),所以我们明确写出来,即要使得\(T+mBN\equiv0\pmod{B^2}\)。根据同余的性质,等式和模数可以同时消去\(B\),即\(T[1:]+mN\equiv0\pmod B\)。所以\(m=-T[1:]N^{-1}\bmod B\),同样\(T\)只需取最低位,即\(m=T[1]N'\bmod B\)。然后令\(T\leftarrow T+mBN\),至此\(T\)就能被\(B^2\)整除了。
- 第2轮:同理计算\(m=T[2]N'\bmod B\),然后令\(T\leftarrow T+mB^2N\),至此\(T\)就能被\(B^3\)整除了。
- 第\(i\)轮:同理计算\(m=T[i]N'\bmod B\),然后令\(T\leftarrow T+mB^iN\),至此\(T\)就能被\(B^{i+1}\)整除了。
- 第\(r-1\)轮,\(T\)已经能被\(B^5\)即\(R\)整除了。
细节的部分直接上代码:
def reduction(T, N, p, R, r, B, k):
"""
:param T: 待约简的数(r + p个字)
:param N: 模数(p个字)
:param p: N的字数
:param R: R(r个字)
:param r: R的字数
:param B: B
:param k: B的位数
:return: T * R^{-1} % N
"""
assert B == 2 ** k
assert R == B ** r
NN = N # 备份一下
N_ = pow(-N, -1, B) # N' = -N^{-1} % B
# 将T和N表示成字的数组
T = [(T >> (i * k)) % B for i in range(r + p)]
N = [(N >> (i * k)) % B for i in range(p)]
T.append(0) # 在最高位加一个字,防止后面溢出,所以T一共是p+1个字
# 每一轮循环使得T的后i+1个字全为0
for i in range(r):
c = 0 # c是进位,后面会用到
m = T[i] * N_ % B # m = T[i] * N' % B
# 下面两个循环都是在计算:T = T + m * B^i * N
# 因为只能计算k位的乘法,所以只能依次用m去乘N[j]
# 至于乘B^i,直接跳过低位的字,即使用T[i+j]
for j in range(p): # 因为是与N相乘,N有p个字就循环p次
t = T[i + j] + m * N[j] + c # 计算中间结果
T[i + j] = t % B # 低位存储在当前字
c = t // B # 高位进位到下一个字
# 这个循环是为了处理多余的进位,只需处理到 r+p (闭区间)
for j in range(p, r + p - i + 1): # 即 i+j 从 p+i 循环到 r+p
t = T[i + j] + c
T[i + j] = t % B
c = t // B
# 除以R,即取前p+1个字
S = 0
for i in range(p + 1):
S = S + T[i + r] * (B ** i)
if S >= NN:
return S - NN
else:
return S
def main():
p = 5
N = 0xf0001
k = 4
B = 2 ** k
r = 5
R = B ** r
T = 0x123456789
print(hex(T * pow(R, -1, N) % N))
print(hex(reduction(T, N, p, R, r, B, k)))
if __name__ == '__main__':
main()完整算法
与上面几乎同理,只不过\(T\)没有提前算好,而是一边迭代一边计算。
def to_words(val, k, num_words):
"""将大整数转换为字数组 (Little Endian)"""
words = []
for _ in range(num_words):
words.append(val % (2 ** k))
val //= 2 ** k
return words
def from_words(words, k):
"""将字数组转回大整数"""
val = 0
for i, w in enumerate(words):
val += w * (2 ** (k * i))
return val
def word_based_MMM(X, Y, N, k, n):
"""
:param X: [0, N)
:param Y: [0, N)
:param N: 模数
:param k: 字长
:param n: R = 2^n, R末尾0的个数,一般也是模数的位数
"""
# 字的个数,s = ceil(n / k)
s = (n + k - 1) // k
# M' = -M^{-1} mod 2^k,模 2^n 也可
M_ = pow(-N, -1, 2 ** k)
# 将输入转为字数组
X_arr = to_words(X, k, s)
Y_arr = to_words(Y, k, s)
M_arr = to_words(N, k, s)
# 初始化 Z 数组,由于计算过程中会有进位,多加了1位
Z_arr = [0] * (s + 1)
# 外循环,遍历 X 的每一个字
for i in range(s):
# 计算 m
tmp_sum = Z_arr[0] + X_arr[i] * Y_arr[0]
m = (tmp_sum * M_) % (2 ** k)
carry = 0 # 将Ca和Cb合并了
# 内循环,将 j=0 也合并进去了
for j in range(s):
# 累加 X[i] * Y[j] 和 m * M[j]
sum_val = Z_arr[j] + X_arr[i] * Y_arr[j] + m * M_arr[j] + carry
# 右移 k
if j > 0:
Z_arr[j - 1] = sum_val % (2 ** k)
# 高 k 位作为进位传给下一次 j 循环
carry = sum_val // (2 ** k)
# 处理剩余的进位
sum_val = Z_arr[s] + carry
Z_arr[s - 1] = sum_val % (2 ** k)
Z_arr[s] = sum_val // (2 ** k)
# 将 Z 数组转回大整数
Z = from_words(Z_arr, k)
if Z >= N:
Z -= N
return Z
if __name__ == "__main__":
n = 20 # R = 2^n
k = 4 # 字长
N = 0xff001 # 模数
R = 2 ** n
X = 0x12345
Y = 0x56789
print((X * Y * pow(R, -1, N)) % N)
print(word_based_MMM(X, Y, N, k, n))