- 并行編程方法與優(yōu)化實(shí)踐
- 劉文志
- 1793字
- 2019-01-01 01:08:32
1.3.4 二維單通道圖像離散卷積
在現(xiàn)在流行的深度神經(jīng)網(wǎng)絡(luò)做圖像識別的應(yīng)用中,有一種關(guān)鍵的運(yùn)算是計(jì)算二維圖像卷積。
這里我們只考慮單通道和單精度浮點(diǎn)類型的二維離散卷積,輸入是一個由浮點(diǎn)的二維數(shù)組保存的圖像,還有一個m×n的模板kernel,m和n一般相等,常用的大小有3、5、7…等。設(shè)圖像大小是M×N,則浮點(diǎn)計(jì)算次數(shù)是2×(M-m+1)×(N-n+1)×m×n。如果m和n比較大,則這是一個典型的計(jì)算訪存比很高的計(jì)算過程。經(jīng)過精心調(diào)優(yōu),應(yīng)該可以充分利用CPU的浮點(diǎn)運(yùn)算性能進(jìn)行優(yōu)化。
這里主要的難點(diǎn)是如何利用SIMD數(shù)據(jù)并行計(jì)算,并做好指令流水調(diào)度。對于4×4或者8×8這樣的kernel,可以直接對kernel做SIMD向量化,但這種方法限制太多,對于非4倍大小的kernel會有冗余和截?cái)?,指令流水長度也會受限,所以需要換一個并行的維度。
一般情況下,輸入圖像和輸出圖像的大小不會太小,可以在輸出圖像上直接做寄存器分塊。SSE指令的向量寄存器長度是4個float,AVX或FMA指令是8個float。前面章節(jié)討論過,SSE和AVX的乘加延遲總和都是8個周期,F(xiàn)MA是10個周期。綜上,對于SSE指令,可以構(gòu)造8×4的寄存器分塊;AVX指令可以構(gòu)造8×8的分塊;FMA指令可以構(gòu)造10×8的分塊。
下面我們就以AVX指令為例來說明如何計(jì)算一個寄存器分塊。這里以3×3的kernel為例。
X64指令集有16個向量寄存器,使用8個向量寄存器保存當(dāng)前計(jì)算的8×8個輸出結(jié)果,首先初始化為0;從左到右每次從kernel中的取出一列值,這里一次取3個,各廣播到一個向量寄存器中;從輸入矩陣中每次讀取一個向量,和前面廣播的3個kernel值向量做乘法,并累加到輸出矩陣當(dāng)前的寄存器分塊中;再從下一行同樣的偏移位置讀一個向量,做相同的操作,直到第8+3-1=10行為止。圖1-6展示了計(jì)算到kernel第二列的值的乘累加位置,其中輸入矩陣?yán)锷罨疑脑乇硎?×3的kernel需要額外讀入的2行2列的數(shù)據(jù)(3-1=2)。
圖1-6 單通道二維卷積的向量寄存器分塊算法
一個8×8的寄存器分塊計(jì)算完畢,就可以寫回到輸出矩陣,并可以開始計(jì)算右邊緊挨著的一個新的寄存器分塊。不斷推進(jìn)這個過程,直到輸出矩陣所有位置被計(jì)算完畢。這個計(jì)算過程對cache的重用已經(jīng)非常好了,不需要再做專門的局部化處理。對于邊界不滿足8×8的分塊,可以用宏或者代碼生成器生成小于8×8的寄存器分塊計(jì)算過程。
下面的代碼片段展示了一個8×8寄存器分塊和3×3大小的kernel計(jì)算一列kernel的過程。
ymm13 = _mm256_broadcast_ss(flt + flt_w * 0 + n); // 從kernel中 ymm14 = _mm256_broadcast_ss(flt + flt_w * 1 + n); // 讀入一列元素 ymm15 = _mm256_broadcast_ss(flt + flt_w * 2 + n); // 廣播到3個向量 ymm10 = _mm256_loadu_ps(src + src_w * 0); // 讀取輸入數(shù)組第0個向量 ymm11 = _mm256_mul_ps(ymm13, ymm10); // 第0個向量只跟kernel最上面元素做乘加 ymm0 = _mm256_add_ps(ymm11, ymm0); ymm10 = _mm256_loadu_ps(src + src_w * 1); // 讀取輸入數(shù)組第1個向量 ymm11 = _mm256_mul_ps(ymm13, ymm10); // 第1個向量跟kernel前兩個元素做乘加 ymm1 = _mm256_add_ps(ymm11, ymm1); ymm11 = _mm256_mul_ps(ymm14, ymm10); ymm0 = _mm256_add_ps(ymm11, ymm0); ymm10 = _mm256_loadu_ps(src + src_w * 2); // 讀取輸入數(shù)組第2個向量 ymm11 = _mm256_mul_ps(ymm13, ymm10); // 從第2個向量開始跟全部3個kernel元素做乘加 ymm2 = _mm256_add_ps(ymm11, ymm2); ymm11 = _mm256_mul_ps(ymm14, ymm10); ymm1 = _mm256_add_ps(ymm11, ymm1); ymm11 = _mm256_mul_ps(ymm15, ymm10); ymm0 = _mm256_add_ps(ymm11, ymm0); ymm10 = _mm256_loadu_ps(src + src_w * 3); ymm11 = _mm256_mul_ps(ymm13, ymm10); ymm3 = _mm256_add_ps(ymm11, ymm3); ymm11 = _mm256_mul_ps(ymm14, ymm10); ymm2 = _mm256_add_ps(ymm11, ymm2); ymm11 = _mm256_mul_ps(ymm15, ymm10); ymm1 = _mm256_add_ps(ymm11, ymm1); ymm10 = _mm256_loadu_ps(src + src_w * 4); ymm11 = _mm256_mul_ps(ymm13, ymm10); ymm4 = _mm256_add_ps(ymm11, ymm4); ymm11 = _mm256_mul_ps(ymm14, ymm10); ymm3 = _mm256_add_ps(ymm11, ymm3); ymm11 = _mm256_mul_ps(ymm15, ymm10); ymm2 = _mm256_add_ps(ymm11, ymm2); ymm10 = _mm256_loadu_ps(src + src_w * 5); ymm11 = _mm256_mul_ps(ymm13, ymm10); ymm5 = _mm256_add_ps(ymm11, ymm5); ymm11 = _mm256_mul_ps(ymm14, ymm10); ymm4 = _mm256_add_ps(ymm11, ymm4); ymm11 = _mm256_mul_ps(ymm15, ymm10); ymm3 = _mm256_add_ps(ymm11, ymm3); ymm10 = _mm256_loadu_ps(src + src_w * 6); ymm11 = _mm256_mul_ps(ymm13, ymm10); ymm6 = _mm256_add_ps(ymm11, ymm6); ymm11 = _mm256_mul_ps(ymm14, ymm10); ymm5 = _mm256_add_ps(ymm11, ymm5); ymm11 = _mm256_mul_ps(ymm15, ymm10); ymm4 = _mm256_add_ps(ymm11, ymm4); ymm10 = _mm256_loadu_ps(src + src_w * 7); ymm11 = _mm256_mul_ps(ymm13, ymm10); ymm7 = _mm256_add_ps(ymm11, ymm7); ymm11 = _mm256_mul_ps(ymm14, ymm10); ymm6 = _mm256_add_ps(ymm11, ymm6); ymm11 = _mm256_mul_ps(ymm15, ymm10); ymm5 = _mm256_add_ps(ymm11, ymm5); ymm10 = _mm256_loadu_ps(src + src_w * 8); ymm11 = _mm256_mul_ps(ymm14, ymm10); // 倒數(shù)第2個向量只跟kernel最后兩個元素做乘加 ymm7 = _mm256_add_ps(ymm11, ymm7); ymm11 = _mm256_mul_ps(ymm15, ymm10); ymm6 = _mm256_add_ps(ymm11, ymm6); ymm10 = _mm256_loadu_ps(src + src_w * 9); // 讀取輸入數(shù)組最后一個向量 ymm11 = _mm256_mul_ps(ymm15, ymm10); // 最后一個向量只跟kernel最后一個元素做乘加 ymm7 = _mm256_add_ps(ymm11, ymm7);
下面來分析一下這個算法的cache效率。對于一個8×8的分塊,總的浮點(diǎn)計(jì)算量是8×8×3×3×2(乘加各算一次浮點(diǎn)操作)=1152;訪問cache的數(shù)量(這里只以讀取float個數(shù)計(jì))是(8×(8+3-1)+3)×3=243;經(jīng)過測試,支持AVX指令的SNB架構(gòu)CPU的cache帶寬是一個周期8個float,同時SNB架構(gòu)一個周期可以發(fā)射8個單精度浮點(diǎn)乘法和單精度浮點(diǎn)加法,即一個周期內(nèi)cache和浮點(diǎn)計(jì)算的吞吐比例是1∶2;但因?yàn)榇蟛糠謈ache訪問并非對齊到向量長度,我們保守估計(jì)cache帶寬會損失一半,即降到1∶4,但仍然高于我們前面計(jì)算過的243∶1152。綜上,cache延遲可以被浮點(diǎn)計(jì)算掩蓋,帶寬足夠滿足浮點(diǎn)計(jì)算的峰值性能。
但由于cache訪問量的減少會導(dǎo)致流水線有部分依賴,這個算法仍不能達(dá)到浮點(diǎn)峰值性能,表1-10展示了不同尺寸的kernel實(shí)測的浮點(diǎn)性能(在Haswell架構(gòu)下使用FMA指令),以及與浮點(diǎn)峰值的比較??梢钥闯?,kernel尺寸越大,計(jì)算訪存比越高,峰值性能也越高,最高可以達(dá)到大約66%的浮點(diǎn)峰值性能。
表1-10 二維離散單通道卷積的性能測試
- Raspberry Pi for Python Programmers Cookbook(Second Edition)
- 軟件測試項(xiàng)目實(shí)戰(zhàn)之性能測試篇
- Python零基礎(chǔ)快樂學(xué)習(xí)之旅(K12實(shí)戰(zhàn)訓(xùn)練)
- Java面向?qū)ο蟪绦蜷_發(fā)及實(shí)戰(zhàn)
- Mastering Ubuntu Server
- Symfony2 Essentials
- MATLAB for Machine Learning
- Learning OpenCV 3 Computer Vision with Python(Second Edition)
- Learning Continuous Integration with TeamCity
- Building Microservices with .NET Core
- JavaScript應(yīng)用開發(fā)實(shí)踐指南
- Node.js 12實(shí)戰(zhàn)
- Distributed Computing in Java 9
- Python 3 Object:oriented Programming(Second Edition)
- Python程序設(shè)計(jì)教程