【09】FFT解大整数乘法

1. 代码文件说明

提交的代码包括两个文件夹,fft 和 judge。fft 文件夹包括的是可独立运行的 fft 求大数乘法的程序(做了输入输出重定向),以及实际提交于 Leetcode 的代码。judge 文件夹包含了三个数据文件(test.data 为十组本地测试数据,ans.data 为正确答案,fftans.data 为 fft.exe 的运行输出结果),两个 python 文件(data.py 为生成测试数据和答案的脚本,judge.py 为判断结果是否正确的脚本)。

代码放在文末

2. 求解思想

  • 系数表示法

在计算乘法的时候,我们可以把一个数字分解为一个多项式 $A(x) = \sum^{n-1}_{i=0}\alpha_ix^i$ 。对于一个固定的 $x$,我们就可以把两个数字分解为 $n$, $m$ 维的两个向量(两个 $x$ 进制数),则乘法结果就是对应的 $n$ 维和 $m$ 维的向量的卷积结果(多项式系数的卷积)。根据 卷积定理

向量卷积的离散傅里叶变换 是 向量离散傅里叶变换的乘积。

我们就可以把两个大数分解为向量 $a$,$b$,然后对两个向量分别做离散傅里叶变换,之后逐位相乘得到一个新向量 $c$。对 $c$ 做逆离散傅里叶变换,再进行进位,就可以得到结果。

快速傅里叶变换可以做到在 $O(NlogN)$ 的复杂度完成原本 $O(N^2)$ 复杂度的离散傅里叶变换,让整个算法的复杂度也降低到 $O(NlogN)$ 。

  • 点值表示法

卷积定理 进一步进行展开解释。在两个整数分解为 $A(x)$ 的时候,我们可以分别对于 $n$ 个不同的 $x$ 值记录 $n$ 个点值对 $(x_i,A_a(x_i))$ 和 $(x_i,A_b(x_i))$ 。那么我们也可以根据这两组点值对直接得到乘法结果对应的一组点值对 $(x_i,A_a(x_i)\times A_b(x_i))$ 。

多项式插值的唯一性定理 证明了这 $n$ 个点值对即可用来表示(还原)对应的多项式。由于乘法结果会有 $2n$ 位,所以我们一开始分解时实际分别记录两个数对应的 $2n$ 个点值对,最终可以使用插值函数的方法恢复出 $2n$ 位的乘法结果。

这和之前系数表示法那一段之间的联系就在于,离散傅里叶变换 DFT 就是一个计算系数表达式与点值表达式之间互换的算法(采样)。只不过实际计算过程中,DFT 选取的 $x$ 值是复数值(n次单位根)。

直接使用 DFT 计算的话,转换步骤的复杂度显然需要 $O(n^2)$ 。而 FFT 能够快速完成系数表达式和点值表达式之间的换算,使复杂度降低到 $O(NlogN)$ 。

  • FFT 计算过程

结合 消去引理折半引理求和引理,我们可以分治的求解 DFT:

一个界长为 $N$ 的离散傅里叶变换可以重新写成两个界长各为 $N/2$ 的离散傅里叶变换之和。其中一个变换由原来 $N$ 个点中的偶数点构成,另一个变换由奇数点构成。这个过程递归进行下去,直到将全部数据细分为界长为 1 的变换。在边界上,界长为 1 的变换等于自身。

为了方便起见,在反转置换之前先补充前导零,使向量长度为 2 的幂形式。然后对向量进行裂项,即反转变换。将一个 2 的 n 次幂长度的向量进行裂项操作,每个元素的位置就会是下标的二进制反转之后再转换成十进制的位置,可以采用 Rader 算法(二进制平摊反转置换算法)实现,这样可以将递归转化成迭代执行。

最终在回溯的时候,不断套用公式执行蝶形运算,就可以得到结果了。

由于递归执行了求解过程,递归树为一棵完全二叉树,所以易知复杂度为 $O(NlogN)$ 。

  • 逆 FFT

只需对原来的 FFT 算法代码进行小小的修改:a 和 y 互换,$ω_n^{-1}$ 代替 $ω_n$ ,最后所有值除以 $n$ 即可。

  • 全部程序执行流程:

读入两个大数字符串,分别转化为数组一位一位存储,并补充前导零。分别进行 FFT 计算,按位相乘,然后做 IFFT 计算。最后将结果进行进位操作,就能得到最终的乘法结果。

3. 代码执行结果

  • 本地数据

共生成了 10 组 100 $\times$ 100 位的随机数。

示例:(第一组大数)

7739385993211797423647071118580282469713569881037743170530795280641276969768173826242862186300508114 9672516198036485560430536046045403561144663114397844686576323489397779756322778671971277864423561283

得到结果:

74859336382197804460271536901670667372336115116031816158713680432079139311256687277234225158377643751548229859235061475750266007841281857699781669159860971180777171200406847628409854302216736317750262

使用 Python 脚本判断十组数据运行正确性:

1544710393052

通过测试。

Snipaste_2018-12-13_22-17-07

Snipaste_2018-12-13_22-23-15

通过。

P.S. Leetcode 上这道题最快的 4 ms 过题代码使用的是最朴素的 $O(n^2)$ 逐位相乘然后进位的算法。我的代码慢了 50 多倍,这说明 FFT 的常数大概是非常大了(至少我这个代码是)。

代码

fft/fft.cpp

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <cstdio>
#include <iostream>
#include <cmath>
#include <cstring>
#include <algorithm>

using namespace std;

const double PI = acos(-1.0);

// 复数结构体
struct Complex{
    double x, y;     // 实部和虚部 x + yi
    Complex(double _x = 0.0, double _y = 0.0){
        x = _x;
        y = _y;
    }
    Complex operator - (const Complex &b) const{
        return Complex(x - b.x, y - b.y);
    }
    Complex operator + (const Complex &b) const{
        return Complex(x + b.x, y + b.y);
    }
    Complex operator * (const Complex &b) const{
        return Complex(x*b.x - y*b.y, x*b.y + y*b.x);
    }
};

// FFT 和 IFFT 前的反转变换。
// 位置 i 和 (i 二进制反转后位置)互换
// len 必须是 2 的幂
void change(Complex y[], int len){
    int i, j, k;
    for(i = 1, j = len / 2; i < len - 1; i++){
        if(i < j) swap(y[i], y[j]);
        // 交换互为小标反转的元素,i < j 保证交换一次
        // i 做正常的 +1,j 左反转类型的 +1,始终保持 i 和 j 是反转的
        k = len / 2;
        while(j >= k){
            j -= k;
            k /= 2;
        }
        if(j < k) j += k;
    }
}

// FFT
// len 必须为 2 的幂,
// on == 1 时是 DFT,on == -1 时是 IDFT
void fft(Complex y[], int len, int on){
    change(y, len);
    for(int h = 2; h <= len; h *= 2){
        Complex wn(cos(-on * 2*PI / h), sin(-on * 2*PI / h));
        for(int j = 0; j < len; j += h){
            Complex w(1, 0);
            for(int k = j; k < j + h/2; k++){
                Complex u = y[k];
                Complex t = w * y[k + h/2];
                y[k] = u + t;
                y[k + h/2] = u - t;
                w = w * wn;
            }
        }
    }
    if(on == -1)
        for(int i = 0; i < len; i++)
            y[i].x /= len;
}

const int MAXN = 200010;
Complex x1[MAXN], x2[MAXN];
char str1[MAXN / 2], str2[MAXN / 2];
int sum[MAXN];

int main(){
    freopen("test.data", "r", stdin);
    freopen("fftans.data", "w", stdout);

    // 输入两个大数字符串
    while(scanf("%s%s", str1, str2)!=EOF){
        int len1 = strlen(str1);
        int len2 = strlen(str2);

        int len = 1;
        while(len < len1 * 2 || len < len2 * 2) len *= 2;

        // 把字符串转化为复数数组并补前导 0 至 2 的幂
        for(int i = 0; i < len1; i++)
            x1[i] = Complex(str1[len1 - 1 - i] - '0', 0);
        for(int i = len1; i < len; i++)
            x1[i] = Complex(0, 0);

        for(int i = 0; i < len2; i++)
            x2[i] = Complex(str2[len2 - 1 - i] - '0', 0);
        for(int i = len2; i < len; i++)
            x2[i] = Complex(0, 0);
        
        // 求DFT
        fft(x1, len, 1);
        fft(x2, len, 1);

        // 相乘
        for(int i = 0; i < len; i++)
            x1[i] = x1[i] * x2[i];
        
        // IDFT
        fft(x1, len, -1);

        // 转化为整数值
        for(int i = 0; i < len; i++)
            sum[i] = (int)(x1[i].x + 0.5);
        for(int i = 0; i < len; i++){
            sum[i+1] += sum[i] / 10;
            sum[i] %= 10;
        }

        // 输出结果
        len = len1 + len2 - 1;
        while(sum[len] <= 0 && len > 0) len--;
        for(int i = len; i >= 0; i--)
            printf("%c", sum[i] + '0');
        printf("\n");
    }
    return 0;
}

fft/leetcode_upload.cpp

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class Solution {
public:
    double PI = acos(-1.0);
    // 复数结构体
    struct Complex{
        double x, y;     // 实部和虚部 x + yi
        Complex(double _x = 0.0, double _y = 0.0){
            x = _x;
            y = _y;
        }
        Complex operator - (const Complex &b) const{
            return Complex(x - b.x, y - b.y);
        }
        Complex operator + (const Complex &b) const{
            return Complex(x + b.x, y + b.y);
        }
        Complex operator * (const Complex &b) const{
            return Complex(x*b.x - y*b.y, x*b.y + y*b.x);
        }
    };
    // FFT 和 IFFT 前的反转变换。
    // 位置 i 和 (i 二进制反转后位置)互换
    // len 必须是 2 的幂
    void change(Complex y[], int len){
        int i, j, k;
        for(i = 1, j = len / 2; i < len - 1; i++){
            if(i < j) swap(y[i], y[j]);
            // 交换互为小标反转的元素,i < j 保证交换一次
            // i 做正常的 +1,j 左反转类型的 +1,始终保持 i 和 j 是反转的
            k = len / 2;
            while(j >= k){
                j -= k;
                k /= 2;
            }
            if(j < k) j += k;
        }
    }

    // FFT
    // len 必须为 2 的幂,
    // on == 1 时是 DFT,on == -1 时是 IDFT
    void fft(Complex y[], int len, int on){
        change(y, len);
        for(int h = 2; h <= len; h *= 2){
            Complex wn(cos(-on * 2*PI / h), sin(-on * 2*PI / h));
            for(int j = 0; j < len; j += h){
                Complex w(1, 0);
                for(int k = j; k < j + h/2; k++){
                    Complex u = y[k];
                    Complex t = w * y[k + h/2];
                    y[k] = u + t;
                    y[k + h/2] = u - t;
                    w = w * wn;
                }
            }
        }
        if(on == -1)
            for(int i = 0; i < len; i++)
                y[i].x /= len;
    }

    Complex x1[200010], x2[200010];
    char str1[100005], str2[100005];
    int sum[200010];

    string multiply(string num1, string num2) {
        // 输入两个大数字符串
        strcpy(str1, num1.c_str());
        strcpy(str2, num2.c_str());

        int len1 = strlen(str1);
        int len2 = strlen(str2);

        int len = 1;
        while(len < len1 * 2 || len < len2 * 2) len *= 2;

        // 把字符串转化为复数数组
        for(int i = 0; i < len1; i++)
            x1[i] = Complex(str1[len1 - 1 - i] - '0', 0);
        for(int i = len1; i < len; i++)
            x1[i] = Complex(0, 0);
        for(int i = 0; i < len2; i++)
            x2[i] = Complex(str2[len2 - 1 - i] - '0', 0);
        for(int i = len2; i < len; i++)
            x2[i] = Complex(0, 0);
        
        // 求DFT
        fft(x1, len, 1);
        fft(x2, len, 1);

        // 相乘
        for(int i = 0; i < len; i++)
            x1[i] = x1[i] * x2[i];
        
        // IDFT
        fft(x1, len, -1);

        // 转化为整数值
        for(int i = 0; i < len; i++)
            sum[i] = (int)(x1[i].x + 0.5);
        for(int i = 0; i < len; i++){
            sum[i+1] += sum[i] / 10;
            sum[i] %= 10;
        }

        // 输出结果
        len = len1 + len2 - 1;
        while(sum[len] <= 0 && len > 0) len--;
        string ans;
        for(int i = len; i >= 0; i--)
            ans.append(1, char(sum[i] + '0'));
        return ans;
    }
};

judge/data.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
from random import randint

test = open("test.data", 'w')
ans = open("ans.data", 'w')

last = str()
temp = str()

for i in range(20):
    last = temp
    temp = str(randint(1, 9))
    for j in range(99):
        temp += str(randint(0, 9))
    test.write(temp)
    if i % 2 == 1:
        test.write("\n")
        ans.write(str(int(temp) * int(last)) + "\n")
    else:
        test.write(" ")

test.close()

judge/judge.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
fftans = open("fftans.data", "r").readlines()
ans = open("ans.data", "r").readlines()

if len(fftans) != len(ans):
    print("Fail")

for i in range(len(fftans)):
    if fftans[i] != ans[i]:
        print("Fail ! No. {}".format(i + 1))
        break
print("Success !")

Courses

3029 Words

2018-12-13 22:30 +0800

本文阅读量
本站访客量