为了账号安全,请及时绑定邮箱和手机立即绑定

深度学习算子优化-FFT

背景

在数字信号和数字图像领域, 对频域的研究是一个重要分支。 我们日常“加工”的图像都是像素级,被称为是图像的空域数据。空域数据表征我们“可读”的细节。如果我们将同一张图像视为信号,进行频谱分析,可以得到图像的频域数据。 观察下面这组图 (来源),频域图中的亮点为低频信号,代表图像的大部分能量,也就是图像的主体信息。暗点为高频信号,代表图像的边缘和噪声。从组图可以看出,Degraded Goofy 与 Goofy 相比,近似的低频信号保留住了 Goofy 的“轮廓”,而其高频信号的增加使得背景噪点更加明显。频域分析使我们可以了解图像的组成,进而做更多的抽象分析和细节处理。

https://img1.sycdn.imooc.com//5acb3c8700013dc501600160.jpg

Goofy and Degraded Goofy

实现图像空域和频域转换的工具,就是傅立叶变换。由于图像数据在空间上是离散的,我们使用傅立叶变换的离散形式 DFT(Discrete Fourier Transform)及其逆变换 IDFT(Inverse Discrete Fourier Transform)。Cooley-Tuckey 在 DFT 的基础上,开发了更快的算法 FFT(Fast Fourier Transform)。https://img1.sycdn.imooc.com//60f5a4a00001c70102770118.jpg

DFT/FFT 在数字图像领域还有一些延伸应用。比如基于 DFT 的 DCT(Discrete Cosine Transform, 离散余弦变换)就用在了图像压缩 JPEG 算法 (来源) 和图像水印算法(来源)。

JPEG 编码是通过色彩空间转换、抽样分块、DCT 变换、量化编码实现的。其中 DCT 变换的使用将图像低频信息和高频信息区分开,在量化编码过程中压缩了少量低频信息、大量高频信息从而获得尺寸上压缩。从猫脸图上可看出随着压缩比增大画质会变差,但是主体信息还是得以保留。

https://img1.sycdn.imooc.com//5acb3c8700013dc501600160.jpg

猫脸不同 jpeg 画质(压缩比)

图像水印算法通过 DCT 将原图转换至频域,选取合适的位置嵌入水印图像信息,并通过 IDCT 转换回原图。这样对原图像的改变较小不易察觉,且水印通过操作可以被提取。https://img1.sycdn.imooc.com//5acb3c8700013dc501600160.jpg

DFT/FFT 在深度学习领域也有延伸应用。 比如利用 FFT 可以降低卷积计算量的特点,FFT_Conv 算法也成为常见的深度学习卷积算法。本文我们就来探究一下频域算法的原理和优化策略。

DFT 的原理及优化

公式

无论是多维的 DFT 运算,还是有基于 DFT 的 DCT/FFT_Conv, 底层的计算单元都是 DFT_1D。 因此,DFT_1D 的优化是整个 FFT 类算子优化的基础。 DFT_1D 的计算公式:

Xk=∑n=0N−1xne−j2πknNk=0,…,N−1X_{k}=\sum_{n=0}^{\mathrm{N}-1} x_{n} e^{-j 2 \pi k \frac{n}{N}} \quad k=0, \ldots, N-1Xk=∑n=0N−1xne−j2πkNnk=0,…,N−1

其中xnx_{n}xn 为长度为 N 的输入信号,e−j2πknNe^{-j 2 \pi k \frac{n}{N}}e−j2πkNn 是 1 的 N 次根,XkX_{k}Xk 为长度为 N 的输出信号。 该公式的矩阵形式为:

[X(0)X(1)⋮X(N−1)]=[WNnk][x(0)x(1)⋮x(N−1)]\left[\begin{array}{c}X(0) \\ X(1) \\ \vdots \\ X(N-1)\end{array}\right]=\left[W_{N}^{n k}\right]\left[\begin{array}{c} \left.x(0\right) \\ x(1) \\ \vdots \\ x(N-1)\end{array}\right]⎣⎢⎢⎢⎢⎡X(0)X(1)⋮X(N−1)⎦⎥⎥⎥⎥⎤=[WNnk]⎣⎢⎢⎢⎢⎡x(0)x(1)⋮x(N−1)⎦⎥⎥⎥⎥⎤

单位复根的性质

DFT_1D 中的WNnk=e−j2πknNW_{N}^{nk} = e^{-j 2 \pi k \frac{n}{N}}WNnk=e−j2πkNn 是 1 的单位复根。直观地看,就是将复平面划分为 N 份,根据 k * n 的值逆时针扫过复平面的圆周。

https://img1.sycdn.imooc.com//5acb3c8700013dc501600160.jpg

单位复根有着周期性和对称性,我们依据这两个性质可以对 W 矩阵做大量的简化,构成 DFT_1D 的快速算法的基础。

周期性:WNk+N=WNkW_{N}^{k +N}=W_{N}^{k}WNk+N=WNk

对称性:WNk+N/2=−WNkW_{N}^{k+N / 2}=-W_{N}^{k}WNk+N/2=−WNk

Cooley-Tuckey FFT 算法

DFT_1D 的多种快速算法中,使用最频繁的是 Cooley-Tuckey FFT 算法。算法采用用分治的思想,将输入尺寸为 N 的序列,按照不同的基 radix,分解为 N/radix 个子序列,并对每个子序列再划分,直到不能再被划分为止。每一次划分都可以得到一级 stage,将所有的级自下而上组合在一起,计算得到最后的输出序列。 这里以 N = 8, radix=2 为例展示推理过程。 其中x(k)x(k)x(k)为 N=8 的序列,XF(k)X^{F}(k)XF(k)为 DFT 输出序列。 根据 DFT 的计算公式

XF(k)=W80x0+W8kx1+W82kx2+W83kx3+W84kx4+W85kx5+W86kx6+W87kx7X^{F}(k)=W_{8}^{0} x_{0}+W_{8}^{k} x_{1}+W_{8}^{2 k} x_{2}+W_{8}^{3k} x_{3}+W_{8}^{4k} x_{4} + W_{8}^{5k} x_{5}+W_{8}^{6k} x_{6} +W_{8}^{7k} x_{7}XF(k)=W80x0+W8kx1+W82kx2+W83kx3+W84kx4+W85kx5+W86kx6+W87kx7

根据奇偶项拆开,分成两个长度为 4 的序列G(k)G(k)G(k),H(k)H(k)H(k)。

XF(k)=W80x0+W82kx2+W84kx4+W86kx6X^{F}(k)= W_{8}^{0} x_{0}+W_{8}^{2 k} x_{2}+W_{8}^{4 k} x_{4}+W_{8}^{6 k} x_{6}XF(k)=W80x0+W82kx2+W84kx4+W86kx6

+W8k(W80x1+W82kx3+W84kx5+W86kx7)+W_{8}^{k}\left(W_{8}^{0} x_{1}+W_{8}^{2 k} x_{3}+W_{8}^{4 k} x_{5}+W_{8}^{6 k} x_{7}\right)+W8k(W80x1+W82kx3+W84kx5+W86kx7)

=GF(k)+W8kHF(k)=G^{F}(k)+W_{8}^{k} H^{F}(k)=GF(k)+W8kHF(k)

XF(k+4)=W80x0+W82(k+4)x2+W84(k+4)x4+W86(k+4)x6X^{F}(k+4)=W_{8}^{0} x_{0}+W_{8}^{2(k+4)} x_{2}+W_{8}^{4(k+4)} x_{4}+W_{8}^{6(k+4)} x_{6}XF(k+4)=W80x0+W82(k+4)x2+W84(k+4)x4+W86(k+4)x6

+W8(k+4)(W80x1+W82(k+4)x3+W84(k+4)x5+W86(k+4)x7)+W_{8}^{(k+4)}\left(W_{8}^{0} x_{1}+W_{8}^{2(k+4)} x_{3}+W_{8}^{4(k+4)} x_{5}+W_{8}^{6(k+4)} x_{7}\right)+W8(k+4)(W80x1+W82(k+4)x3+W84(k+4)x5+W86(k+4)x7)

=GF(k)+W8k+4HF(k)=G^{F}(k)+W_{8}^{k+4} H^{F}(k)=GF(k)+W8k+4HF(k)

=GF(k)−W8kHF(k)=G^{F}(k)-W_{8}^{k} H^{F}(k)=GF(k)−W8kHF(k)

GF(k)G^{F}(k)GF(k)和HF(k)H^{F}(k)HF(k)为G(k)G(k)G(k)和H(k)H(k)H(k)的 DFT 结果。GF(k)G^{F}(k)GF(k)和HF(k)H^{F}(k)HF(k)乘以对应的旋转因子W8kW_{8}^{k}W8k,进行简单的加减运算可以得到输出XF(k)X^{F}(k)XF(k)。 同理,对G(k)G(k)G(k)和H(k)H(k)H(k)也做一样的迭代,A(k)A(k)A(k),B(k)B(k)B(k),C(k)C(k)C(k),D(k)D(k)D(k) 都是 N=2 的序列,用他们的 DFT 结果进行组合运算可以得到GF(k)G^{F}(k)GF(k)和HF(k)H^{F}(k)HF(k)。

GF(k)=AF(k)+W4kBF(k)\begin{aligned} &G^{F}(k)=A^{F}(k) + W_{4}^{k}B^{F}(k)\\ \end{aligned}GF(k)=AF(k)+W4kBF(k)

GF(k+2)=AF(k)−W4kBF(k)\begin{aligned} &G^{F}(k+2)=A^{F}(k)-W_{4}^{k}B^{F}(k)\\ \end{aligned}GF(k+2)=AF(k)−W4kBF(k)

HF(k)=CF(k)+W4kDF(k)\begin{aligned} &H^{F}(k)=C^{F}(k)+W_{4}^{k}D^{F}(k)\\ \end{aligned}HF(k)=CF(k)+W4kDF(k)

HF(k+2)=CF(k)−W4kDF(k)\begin{aligned} &H^{F}(k+2)=C^{F}(k)-W_{4}^{k}D^{F}(k)\\ \end{aligned}HF(k+2)=CF(k)−W4kDF(k)

计算 N=2 的序列AF(k)A^{F}(k)AF(k),BF(k)B^{F}(k)BF(k),CF(k)C^{F}(k)CF(k),DF(k)D^{F}(k)DF(k), 因为k=0k=0k=0,旋转因子W20W_{2}^{0}W20= 1。只要进行加减运算得到结果。

[AF(0)AF(1)]=[111−1][x0x4]\left[\begin{array}{l} A^{F}(0) \\ A^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{0} \\ x_{4} \\ \end{array}\right][AF(0)AF(1)]=[111−1][x0x4]

[BF(0)BF(1)]=[111−1][x2x6]\left[\begin{array}{l} B^{F}(0) \\ B^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{2} \\ x_{6} \\ \end{array}\right][BF(0)BF(1)]=[111−1][x2x6]

[CF(0)CF(1)]=[111−1][x1x5]\left[\begin{array}{l} C^{F}(0) \\ C^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{1} \\ x_{5} \\ \end{array}\right][CF(0)CF(1)]=[111−1][x1x5]

[DF(0)DF(1)]=[111−1][x3x7]\left[\begin{array}{l} D^{F}(0) \\ D^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{3} \\ x_{7} \\ \end{array}\right][DF(0)DF(1)]=[111−1][x3x7]

用算法图形表示,每一层的计算会产生多个蝶形,因此该算法又被称为蝶形算法。 这里我们要介绍碟形网络的基本组成,对下文的分析有所帮助。https://img1.sycdn.imooc.com//5acb3c8700013dc501600160.jpg

N=8 碟形算法图

N=8 的计算序列被分成了 3 级,每一级 (stage) 有一个或多个块 (section),每个块中包含了一个或者多个蝶形(butterfly), 蝶形的计算就是 DFT 运算的 kernel。 每一个 stage 的计算顺序:

  • 取输入

  • 乘以转换因子

  • for section_num, for butterfly_num,执行 radixN_kernel

  • 写入输出。

看 N=8 的蝶形算法图,stage = 1 时,运算被分成了 4 个 section,每个 section 的 butterfly_num = 1。stage = 2 时,section_num = 2,butterfly_num = 2。 stage = 3 时,section_num = 1,butterfly_num = 4。 可以观察到,从左到右过程中 section_num 不断减少,butterfly_num 不断增加,蝶形群在“变大变密”,然而每一级总的碟形次数是不变的。 实际上,对于长度为 N,radix = r 的算法,我们可以推得到:

Sec\text SecSec _num\text numnum =N/rSN / r^ {S}N/rS

Butterfly\text ButterflyButterfly _num\text numnum  =rS−1r^{S-1}rS−1

Sec\text SecSec _stride=rS\text stride =r^{S}stride=rS

Butterfly\text ButterflyButterfly_stridestridestride=111

S 为当前的 stage,sec/butterfly_stride 是每个 section/butterfly 的间隔。复制代码

这个算法可以将复杂度从 O(n^2) 下降到 O(nlogn),显得高效而优雅。我们基于蝶形算法,对于不同的 radix 进行算法的进一步划分和优化,主要分为 radix - 2 的幂次的和 radix – 非 2 的幂次两类。

radix-2 的幂次优化

DFT_1D 的 kernel 即为矩阵形式中的WNnkW_{N}^{nk}WNnk矩阵,我们对 radix_2^n 的 kernel 进行分析。

背景里提到, DFT 公式的矩阵形式为:

[X(0)X(1)⋮X(N−1)]=[WNnk][x(0)x(1)⋮x(N−1)]\left[\begin{array}{c}X(0) \\ X(1) \\ \vdots \\ X(N-1)\end{array}\right]=\left[W_{N}^{n k}\right]\left[\begin{array}{c} \left.x(0\right) \\ x(1) \\ \vdots \\ x(N-1)\end{array}\right]⎣⎢⎢⎢⎢⎡X(0)X(1)⋮X(N−1)⎦⎥⎥⎥⎥⎤=[WNnk]⎣⎢⎢⎢⎢⎡x(0)x(1)⋮x(N−1)⎦⎥⎥⎥⎥⎤

其中x(0)x(0)x(0) ~x(N−1)x(N-1)x(N−1)为乘以旋转因子WNknW_{N}^{kn}WNkn后的输入

当 radix = 2 时,由于W21W_{2}^1W21  = -1,W22W_{2}^2W22 = 1, radix_2 的 DFT 矩阵形式可以写为:

[XkXk+N/2]\left[\begin{array}{c}\mathrm{X}_{\mathrm{k}} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 2}\end{array}\right][XkXk+N/2]=[111−1][WN0AkWNkBk]=\left[\begin{array}{cc}1 & 1 \\ 1 & -1\end{array}\right]\left[\begin{array}{l}\mathrm{W}_{\mathrm{N}}^{0} \mathrm{A}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{\mathrm{k}} \mathrm{B}_{\mathrm{k}}\end{array}\right]=[111−1][WN0AkWNkBk]

当 radix = 4 时,由于W41W_{4}^1W41 = -j,W42W_{4}^2W42 = -1,W43W_{4}^3W43  = j,W44W_{4}^4W44= 1,radix_4 的 DFT 矩阵形式可以写为:

[XkXk+N/4Xk+N/2Xk+3 N/4]=[11111−j−1j1−11−11j−1−j][WN0AkWNkBkWN2kCkWN3kDk]\left[\begin{array}{c}\mathrm{X}_{\mathrm{k}} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 4} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 2} \\ \mathrm{X}_{\mathrm{k}+3 \mathrm{~N} / 4}\end{array}\right]=\left[\begin{array}{cccc}1 & 1 & 1 & 1 \\ 1 & -\mathrm{j} & -1 & \mathrm{j} \\ 1 & -1 & 1 & -1 \\ 1 & \mathrm{j} & -1 & -\mathrm{j}\end{array}\right]\left[\begin{array}{c}\mathrm{W}_{\mathrm{N}}^{0} \mathrm{A}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{\mathrm{k}} \mathrm{B}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{2 \mathrm{k}} \mathrm{C}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{3 \mathrm{k}} \mathrm{D}_{\mathrm{k}}\end{array}\right]⎣⎢⎢⎢⎡XkXk+N/4Xk+N/2Xk+3 N/4⎦⎥⎥⎥⎤=⎣⎢⎢⎢⎡11111−j−1j1−11−11j−1−j⎦⎥⎥⎥⎤⎣⎢⎢⎢⎡WN0AkWNkBkWN2kCkWN3kDk⎦⎥⎥⎥⎤

同理推得到 radix_8 的 kernel 为:

[111111111 W81−j W83−1−W81j−W831−j−1j1−j−1j1 W83j W81−1−W83−j−W811−11−11−11−11−W81−j−W83−1 W81j W831j−1−j1j−1−j1−W83j−W81−1 W83−j W81]\left[\begin{array}{cccccccc}1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & \mathrm{~W}_{8}^{1} & -j & \mathrm{~W}_{8}^{3} & -1 & -\mathrm{W}_{8}^{1} & j & -\mathrm{W}_{8}^{3} \\ 1 & -j & -1 & j & 1 & -j & -1 & j \\ 1 & \mathrm{~W}_{8}^{3} & j & \mathrm{~W}_{8}^{1} & -1 & -\mathrm{W}_{8}^{3} & -j & -\mathrm{W}_{8}^{1} \\ 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 \\ 1 & -\mathrm{W}_{8}^{1} & -j & -\mathrm{W}_{8}^{3} & -1 & \mathrm{~W}_{8}^{1} & j & \mathrm{~W}_{8}^{3} \\ 1 & j & -1 & -j & 1 & j & -1 & -j \\ 1 & -\mathrm{W}_{8}^{3} & j & -\mathrm{W}_{8}^{1} & -1 & \mathrm{~W}_{8}^{3} & -j & \mathrm{~W}_{8}^{1}\end{array}\right]⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡111111111 W81−j W83−1−W81j−W831−j−1j1−j−1j1 W83j W81−1−W83−j−W811−11−11−11−11−W81−j−W83−1 W81j W831j−1−j1j−1−j1−W83j−W81−1 W83−j W81⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤

我们先来看访存,现代处理器对于计算性能的优化要优于对于访存的优化,在计算和访存相近的场景下, 访存通常是性能瓶颈。

DFT1D 中,对于不同基底的算法 r-2/r-4/r-8, 每一个 stage 有着相等的存取量:2 * butterfly_num * radix = 2N,  而不同的基底对应的 stage 数有着明显差异(log⁡2N\log_2Nlog2N vslog⁡4N\log_4Nlog4N vslog⁡8N\log_8Nlog8N)。

因此对于 DFT, 在不显著增加计算量的条件下, 选用较大的 kernel 会在访存上取得明显的优势。观察推导的 kernel 图, r-2 的 kernel 每个蝶形对应 4 次访存操作和,2 次复数浮点加减运算。r-4 的 kernel 每个蝶形算法对应 8 次 load/store、8 次复数浮点加减操作(合并相同的运算),在计算量略增加的同时 stage 由log⁡2N\log_2Nlog2N 下降到log⁡4N\log_4Nlog4N , 降低了总访存的次数, 因此会有性能的提升。r-8 的 kernel 每个蝶形对应 16 次 load/store、24 次复数浮点加法和 8 次浮点乘法。浮点乘法的存在使得计算代价有所上升, stage 由log⁡4N\log_4Nlog4N 进一步下降到log⁡8N\log_8Nlog8N ,但由于 N 日常并不会太大, r-4 到 r-8 的 stage 减少不算明显,所以优化有限

我们再来看计算的开销。减少计算的开销通常有两种办法:减少多余的运算、并行化。

以 r-4 算法为例,kernel 部分的计算为:

  • radix_4_first_stage(src, dst, sec_num, butterfly_num)

  • radix_4_other_stage(src, dst, sec_num, butterfly_num)

    • for butterfly_num

    • raidx_4_kernel

    • for Sec_num

radix4_first_stage 的数据由于 k=0, 旋转因子都为 1,可以省去这部分复数乘法运算,单独优化。 radix4_other_stage 部分, 从第 2 个 stage 往后, butterfly_num = 4^(s-1) 都为 4 的倍数,而每个 butterfly 数组读取/存储都是间隔的。可以对最里层的循环做循环展开加向量化,实现 4 个或更多 butterfly 并行运算。循环展开和 SIMD 指令的使用不仅可以提高并行性, 也可以提升 cacheline 利用的效率,可以带来较大的性能提升。 以 SM8150(armv8) 为例,r-4 的并行优化可以达到 r2 的 1.6x 的性能。

https://img1.sycdn.imooc.com//60f5a4a00001b4ca06840120.jpg

尺寸:1 * 2048(r2c) 环境:SM8150 大核

总之,对于 radix-2^n 的优化,选用合适的 radix 以减少多 stage 带来的访存开销,并且利用单位复根性质以及并行化降低计算的开销,可以带来较大的性能提升。

radix-非 2 的幂次优化

当输入长度 N = radix1^m1 * radix2^m2... 且 radix 都不为 2 的幂次时,如果使用 naive 的 O(n^2) 算法, 性能就会急剧下降。 常见的解决办法对原长补 0、使用 radix_N 算法、特殊的 radix_N 算法 (chirp-z transform)。补 0 至 2 的幂次方法对于大尺寸的输入要增加很多运算量和存储量, 而 chirp-z transform 是用卷积计算 DFT, 算法过于复杂。因此对非 2 的幂次 radix-N 的优化也是必要的。

radix-N 计算流程和 radix-2 幂次一样,我们同样可以利用单位复根的周期性和对称性,对 kernel 进行计算的简化。 以 radix-5 为例,radix-5 的 DFT_kernel 为:

[111111W51W52W5−2W5−11W52W5−1W51W5−21W5−2W51W5−1W521W5−1W5−2W52W51]\left[\begin{array}{cccc} 1&1&1&1&1\\ 1 &\mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{-1} \\ 1 &\mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{-2} \\ 1 &\mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{2} \\ 1 &\mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{1} \\ \end{array}\right]⎣⎢⎢⎢⎢⎢⎡111111W51W52W5−2W5−11W52W5−1W51W5−21W5−2W51W5−1W521W5−1W5−2W52W51⎦⎥⎥⎥⎥⎥⎤

W5kW_5^kW5k 和W5−kW_{5}^{-k}W5−k在复平面上根据 x 轴对称,有相同的实部和相反的虚部。根据这个性质。如下图所示,对于每一个 stage,可以合并公共项 A,B,C,D,再根据公共项计算出该 stage 的输出。

A=(x1+x4)∗W51⋅r+(x2+x3)∗W52⋅rA=\left(x_{1}+x_{4}\right) * W_{5}^{1} \cdot r+\left(x_{2}+x_{3}\right) * W_{5}^{2} \cdot rA=(x1+x4)∗W51⋅r+(x2+x3)∗W52⋅r

B=(−j)∗[(x1−x4)∗W51⋅i+(x2−x3)∗W52⋅i]B=(-j) * \left[\left(x_{1}-x_{4}\right) * W_{5}^{1} \cdot i+\left(x_{2}-x_{3}\right) * W_{5}^{2} \cdot i\right]B=(−j)∗[(x1−x4)∗W51⋅i+(x2−x3)∗W52⋅i]

C=(x1+x4)∗W52⋅r+(x2+x3)∗W51⋅rC=\left(x_{1}+x_{4}\right) * W_{5}^{2} \cdot r+\left(x_{2}+x_{3}\right) * W_{5}^{1} \cdot rC=(x1+x4)∗W52⋅r+(x2+x3)∗W51⋅r

D=j∗[(x1−x4)∗W52⋅i−(x2−x3)∗W51⋅i]D=j * \left[\left(x_{1}-x_{4}\right) * W_{5}^{2} \cdot i-\left(x_{2}-x_{3}\right) * W_{5}^{1} \cdot i\right]D=j∗[(x1−x4)∗W52⋅i−(x2−x3)∗W51⋅i]

X(k)=x0+(x1+x4)+(x2+x3)\begin{array}{l} X(k)=x_{0}+\left(x_{1}+x_{4}\right)+\left(x_{2}+x_{3}\right)\\ \end{array}X(k)=x0+(x1+x4)+(x2+x3)

X(k+N/5)=x0+A−B\begin{array}{l} X(k+N/5)=x_{0}+\mathrm{A}-\mathrm{B}\\ \end{array}X(k+N/5)=x0+A−B

X(k+2N/5)=x0+C+D\begin{array}{l} X(k+2N/5)=x_{0}+\mathrm{C}+\mathrm{D}\\ \end{array}X(k+2N/5)=x0+C+D

X(k+3N/5)=x0+C−D\begin{array}{l} X(k+3N/5)=x_{0}+C-D\\ \end{array}X(k+3N/5)=x0+C−D

X(k+4N/5)=x0+A+B\begin{array}{l} X(k+4N/5)=x_{0}+\mathrm{A}+\mathrm{B}\\ \end{array}X(k+4N/5)=x0+A+B

这种算法减少了很多重复的运算。同时,在 stage>=2 的时候,同样对 butterfly 做循环展开加并行化,进一步减少计算的开销。 radix-5 的优化思想可以外推至 radix-N。对于 radix_N 的每一个 stage, 计算流程为:

  • 取输入

  • 乘以对应的转换因子

  • 计算公共项, radix_N 有 N-1 个公共项

  • 执行并行化的 radix_N_kernel

  • 写入输出

其他优化

上述两个章节描述的是 DFT_1D 的通用优化,在此基础上还可以做更细致的优化,可以参考本文引用的论文。

  • 对于全实数输入的, 由于输入的虚部为 0, 进行旋转因子以及 radix_N_kernel 的复数运算时会有多余的运算和多余的存储, 可以利用 split r2c 算法, 视为长度为 N/2 的复数序列, 计算 DFT 结果并进行 split 操作得到 N 长实数序列的结果。

  • 对于 radix-2 的幂次算法, 重新计算每个 stage 的输入/输出 stride 以取消第一级的位元翻转可以进一步减少访存的开销。

  • 对于 radix-N 算法, 在混合基框架下 N = radix1^m1 * radix2^m2, 合并较小的 radix 为大的 radix 以减少 stage。

DFT 延展算法的原理及优化

DCT 和 FFT_conv 两个典型的基于 DFT 延展的算法,DFT_1D/2D 的优化可以很好的用在这类算法中。

DCT

DCT 算法(Discrete Cosine Transform, 离散余弦变换)可以看作是 DFT 取其正弦分量并经过工业校正的算法。DFT_1D 的计算公式为:

X[k]=C(k)∑n=0N−1x[n]cos⁡((2n+1)πk2N)C(k)=1nk=1C(k)=2nk!=1\begin{aligned} X[k] &=C(k) \sum_{n=0}^{N-1} x[n] \cos \left(\frac{(2 n+1) \pi k}{2 N}\right) \\ &C(k)=\sqrt{\frac{1}{n}} \\&k=1 \\ &C(k)=\sqrt{\frac{2}{n}} \\&k!=1 \\ \end{aligned}X[k]=C(k)n=0∑N−1x[n]cos(2N(2n+1)πk)C(k)=n1k=1C(k)=n2k!=1

该算法 naive 实现是 O(n^2) 的,而我们将其转换成 DFT_1D 算法,可以将算法复杂度降至 O(nlogn)。 基于 DFT 的 DCT 算法流程为:

  • 对于 DCT 的输入序列 x[n], 创建长为 2N 的输入序列 y[n] 满足 y[n] = x[n] + x[2N-n-1], 即做一个镜像对称。

  • 对输入序列 y[n] 进行 DFT 运算,得到输出序列 Y[K]。

  • 由 Y[K] 计算得到原输入序列的输出 X[K] 。

我们尝试推导一下这个算法:

y[n]=x[n]+x[2N−1−n]\begin{array}{l} y[n]=x[n]+x [2 N-1-n] \\  \end{array}y[n]=x[n]+x[2N−1−n]

Y[k]=∑n=0N−1x[n]⋅e−j2πkn2N+∑n=N2N−1x[2N−1−n]⋅e−j2πkn2N\begin{array}{l} Y[k]=\sum_{n=0}^{N-1} x[n]\cdot e^{-j \frac{2 \pi k n}{2 N}} +\sum_{n=N}^{2 N-1} x[2 N-1-n] \cdot e^{-j \frac{2 \pi k n}{2 N}} \end{array}Y[k]=∑n=0N−1x[n]⋅e−j2N2πkn+∑n=N2N−1x[2N−1−n]⋅e−j2N2πkn

=∑n=0N−1x[n]⋅e−j2πkn2N+∑n=0N−1x[n]⋅e−j2πk(2N−1−n)2N\begin{array}{l} \\=\sum_{n=0}^{N-1} x[n]\cdot e^{-j \frac{2 \pi k n}{2 N}} +\sum_{n=0}^{N-1} x[n] \cdot e^{-j \frac{2 \pi k (2N-1-n)}{2 N}} \end{array}=∑n=0N−1x[n]⋅e−j2N2πkn+∑n=0N−1x[n]⋅e−j2N2πk(2N−1−n)

=e−j2πk2N⋅∑n=0N−1x[n](e−j2π2Nkn⋅e−jπ2Nk+ej2π2Nkn⋅ejπ2Nk)\begin{array}{l} \\=e^{-j \frac{2 \pi k }{2 N}} \cdot \sum_{n=0}^{N-1} x[n] (e^{-j \frac{2\pi}{2 N} kn} \cdot e^{-j \frac{\pi}{2 N}k}+e^{j \frac{2\pi}{2 N} kn} \cdot e^{j \frac{\pi}{2 N}k}) \end{array}=e−j2N2πk⋅∑n=0N−1x[n](e−j2N2πkn⋅e−j2Nπk+ej2N2πkn⋅ej2Nπk)

=e−j2πk2N⋅∑n=0N−1x[n]⋅2⋅cos⁡(2n+12Nkπ)\begin{array}{l} \\=e^{-j \frac{2 \pi k }{2 N}} \cdot \sum_{n=0}^{N-1} x[n] \cdot 2\cdot\cos(\frac{2n+1}{2N} k\pi) \end{array}=e−j2N2πk⋅∑n=0N−1x[n]⋅2⋅cos(2N2n+1kπ)

=e−j2πk2N⋅C(u)⋅X[k]\begin{array}{l} \\=e^{-j \frac{2 \pi k }{2 N}} \cdot C(u) \cdot X[k] \end{array}=e−j2N2πk⋅C(u)⋅X[k]

对 y[n] 依照 DFT 公式展开,整理展开的两项并提取公共项e−j2πk2Ne^{-j \frac{2 \pi k }{2 N}}e−j2N2πk, 根据欧拉公式和诱导函数,整理非公共项(e−j2π2Nkn⋅e−jπ2Nk+ej2π2Nkn⋅ejπ2Nk)(e^{-j \frac{2\pi}{2 N} kn} \cdot e^{-j \frac{\pi}{2 N}k}+e^{j \frac{2\pi}{2 N} kn} \cdot e^{j \frac{\pi}{2 N}k})(e−j2N2πkn⋅e−j2Nπk+ej2N2πkn⋅ej2Nπk)。可以看出得到的结果正是 x[k] 和与 k 有关的系数的乘积。这样就可以通过先计算Y[k]Y[k]Y[k]得到 x[n] 的 DCT 输出X[k]X[k]X[k] 。

在理解算法的基础上,我们对 DFT_1D 的优化可以完整地应用到 DCT 上。DCT_2D 的计算过程是依次对行、列做 DCT_1D, 我们用多线程对 DCT_1D 进行并行,可以进一步优化算法。

FFT_conv

Conv 是深度学习最常见的运算,计算 conv 常用的方法有 IMG2COL+GEMM, Winograd, FFT_conv。三种算法都有各自的使用场景。

FFT_conv 的数学原理是时域中的循环卷积对应于其离散傅里叶变换的乘积。如下图所示, f 和 g 的卷积等同于将 f 和 g 各自做傅立叶变幻 F,进行点乘并通过傅立叶逆变换计算后的结果。

f∗ Circ g=F−1(F(f)⋅F(g))f \underset{\text { Circ }}{*} g=\mathcal{F}^{-1}(\mathcal{F}(f) \cdot \mathcal{F}(g))f Circ ∗g=F−1(F(f)⋅F(g))

直观的理论证明可下图(来源)。

F[f∗g]{F}[f * g]F[f∗g]

=∫−∞∞[(∫−∞∞g(z)f(x−z)dz)e−ikx]dx=\int_{-\infty}^{\infty}\left[\left(\int_{-\infty}^{\infty}g(z)f(x-z)dz\right)e^{-i k x}\right]dx=∫−∞∞[(∫−∞∞g(z)f(x−z)dz)e−ikx]dx

=∫−∞∞g(z)[∫−∞∞f(x−z)e−ikxdx]dz=\int_{-\infty}^{\infty} g(z)\left[\int_{-\infty}^{\infty} f(x-z) e^{-i k x} d x\right] d z=∫−∞∞g(z)[∫−∞∞f(x−z)e−ikxdx]dz

=∫−∞∞g(z)[∫−∞∞f(y)e−ik(y+z)dy]dz=\int_{-\infty}^{\infty} g(z)\left[\int_{-\infty}^{\infty} f(y) e^{-i k(y+z)} d y\right] d z=∫−∞∞g(z)[∫−∞∞f(y)e−ik(y+z)dy]dz

=[∫−∞∞g(z)e−ikzdz][∫−∞∞f(y)e−ikydy]=\left[\int_{-\infty}^{\infty} g(z) e^{-i k z} d z\right]\left[\int_{-\infty}^{\infty} f(y) e^{-i k y} d y\right]=[∫−∞∞g(z)e−ikzdz][∫−∞∞f(y)e−ikydy]

=F[f]⋅F[g]=\mathcal{F}[f] \cdot \mathcal{F}[g]=F[f]⋅F[g]

将卷积公式和离散傅立叶变换展开, 改变积分的顺序并且替换变量, 可以证明结论。 注意这里的卷积是循环卷积, 和我们深度学习中常用的线性卷积是有区别的。 利用循环卷积计算线性卷积的条件为循环卷积长度 L⩾| f |+| g |−1。 因此我们要对 Feature Map 和 Kernel 做 zero-padding,并从最终结果中取有效的线性计算结果。

FFT_conv 算法的流程:

  • 将 Feature Map 和 Kernel 都 zero-pad 到同一个尺寸,进行 DFT 转换。 

  • 矩阵点乘

  • 将计算结果通过 IDFT 计算出结果。

该算法将卷积转换成点乘, 算法复杂度是 O(nlogn), 小于卷积的 O(n^2), 在输入的尺寸比较大时可以减少运算量,适用于大 kernel 的 conv 算法。

深度学习计算中, Kernel 的尺寸要远小于 Feature Map, 因此 FFT_conv 第一步的 zero-padding 会有很大的开销,参考论文 2 里提到可以通过对 Feature map 进行分块, 分块后的 Feature Map 和 Kernel 需要 padding 到的尺寸较小,可以大幅减小这一部分的开销。 优化后 fft_conv 的计算流程为:

  • 合理安排缓存计算出合适的 tile 尺寸,对原图进行分块

  • 分块后的小图和 kernel 进行 zero-padding, 并进行 DFT 运算

  • 小图矩阵点乘

  • 进行逆运算并组合成大图。

同时我们可以观察到,FFT_conv 的核心计算模块还是针对小图的 DFT 运算, 因此我们可以将前一章节对 DFT 的优化代入此处,辅以多线程,进一步提升 FFT_Conv 的计算效率。


作者:MegEngine
链接:https://juejin.cn/post/6986542310905348127
来源:掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消