前兩天為了加速一段求梯度的代碼,用了SSE指令,在實驗室PMH大俠的指導下,最終實現了3倍速度提升(極限是4倍,因為4個浮點數一起計算)。在這裡寫一下心得,歡迎拍磚。
SSE加速的幾個關鍵是:
(1) 用於並行計算的數據結構要16字節對齊
(2) 直接寫匯編,不要用SSE的Load Store指令
(3) 對於SSE本身不提供的三角函數等指令,可以用查表法,但要用SSE來算索引號
相比起用GPU加速來說,SSE的並行性要低一些,而且提供的指令、功能函數也要少,但是使用起來相對要簡單一些,而且也不存在紋理傳送進出顯存的overhead。
原先的代碼是這樣的:
1
// 計(jì)算梯度的代碼
2
for (int s = 1 ; s < (GetCount() - 1) ; ++s)
{
3
for (int y = 1 ; y < (imgScaled[s]->Height() - 1) ; ++y)
{
4
for (int x = 1 ; x < (imgScaled[s]->Width() - 1) ; ++x)
{
5
float gy= imgScaled[s]->At(x, y + 1) - imgScaled[s]->At(x, y - 1);
6
float gx = imgScaled[s]->At(x + 1, y) - imgScaled[s]->At(x - 1, y);
7
8
magnitudes[s]->At(x, y) = sqrt(gx*gx + gy*gy);
9
directions[s]->At(x, y) = AtanLookupF32::Value(gy, gx);
10
}
11
}
12
}
13
14
// arctan 查表函數(shù)
15
static inline float AtanLookupF32::Value(float y,float x)
{
16
float N_DOUBLE = 4 * 4096;
17
if( x > 0.0 )
{
18
if( y > 0.0 )
19
return m_dATAN_LU[(int)(N_DOUBLE * y / ( x + y ))];
20
else
21
return -m_dATAN_LU[(int)(N_DOUBLE * (-y) / ( x - y ))];
22
}
23
24
if( x == 0.0 )
{
25
if( y > 0 )
26
return LU_PI/2;
27
else
28
return -LU_PI/2;
29
}
30
31
if( y < 0.0 )
32
return m_dATAN_LU[(int)(N_DOUBLE * y / ( x + y ))] - LU_PI;
33
else
34
return -m_dATAN_LU[(int)(N_DOUBLE * (-y) / ( x - y ))] + LU_PI;
35
}
36
從profiling的角度講,清單第5-9行的代碼(最內層循環中計算gx、gy、梯度幅值和方向的語句)以及ATan查表函數都是要優化到極限的,幸運的是梯度計算部分可並行性很高,但是下標加一減一的部分很容易使16字節對齊的要求不能符合,為此,做了兩步工作,一是讓圖像每一行的起始地址變成16字節對齊,並補全每行的長度為16字節整數倍,二是對每一幅圖像建立一個移位的圖像,用於SSE下檢索坐標加一減一的值。代碼如下:
1
template<typename T>
2
class ImageArray
3

{
4
protected:
5
int m_nWidth;
6
int m_nHeight;
7
8
// 16字節(jié)補(bǔ)齊后的實(shí)際寬度,單位為 sizeof(float)
9
int m_nWidthActual;
10
11
// 積分圖像,用來(lái)加速圖像的區(qū)域求和用
12
ImageArray* m_pImageIntegral;
13
14
// 計(jì)算補(bǔ)足后的長(zhǎng)度
15
static __forceinline int expandAlign(int w)
{
16
return w + 3 - (w - 1) % 4;
17
}
18
19
// 數(shù)據(jù)
20
T* m_afData;
21
T** m_aafEntry;
22
23
typedef T* PointerType;
24
typedef T** EntryType;
25
26
void SetSize(int height, int width)
{
27
m_nWidth = width;
28
m_nHeight = height;
29
m_nWidthActual = expandAlign(width);
30
31
// 16字節(jié)對(duì)齊的分配
32
m_afData = (T*)_aligned_malloc(sizeof(T) * m_nWidthActual * m_nHeight, 16);
33
34
// 這一部分是加速索引,參考Wild Magic Lib里的GMatrix類
35
m_aafEntry = new PointerType[m_nHeight];
36
T* ptr = m_afData;
37
for(int i=0;i<m_nHeight;i++)
{
38
m_aafEntry[i] = ptr;
39
ptr += m_nWidthActual;
40
}
41
42
if(m_pImageIntegral)
43
delete m_pImageIntegral;
44
m_pImageIntegral = NULL;
45
}
46
47
public:
48
49
ImageArray():m_pImageIntegral(NULL)
{SetSize(0, 0);}
50
51
ImageArray(int width, int height):m_pImageIntegral(NULL)
{
52
SetSize(height, width);
53
}
54
55
ImageArray(const ImageArray& that):m_pImageIntegral(NULL)
{
56
SetSize(that.Height(), that.Width());
57
memcpy(m_afData, that.m_afData, sizeof(T) * that.m_nWidthActual * that.m_nHeight);
58
}
59
60
~ImageArray()
{
61
if(m_pImageIntegral)
62
delete m_pImageIntegral;
63
if(m_aafEntry)
64
delete []m_aafEntry;
65
66
// 對(duì)應(yīng)的釋放
67
if(m_afData)
68
_aligned_free(m_afData);
69
}
70
71
void CreateDataArray(int width, int height)
{
72
m_nWidthActual = expandAlign(width);
73
SetSize(height, m_nWidthActual);
74
m_nWidth = width;
75
m_nHeight = height;
76
}
77
78
__forceinline T& At(int x, int y)
{
79
_ASSERT(m_afData);
80
_ASSERT(x >= 0 && x < m_nWidth && y >= 0 && y < m_nHeight);
81
return m_aafEntry[y][x];
82
}
83
84
__forceinline const int Width() const
{return m_nWidth;}
85
__forceinline const int Height() const
{return m_nHeight;}
86
87
// 建立移位的圖像
88
void fillShiftedImage(int shift, ImageArray<T>& dst)
89
{
90
for(int i=0;i<m_nHeight;i++)
91
{
92
memcpy(dst[i], m_aafEntry[i] + shift, sizeof(T) * (m_nWidthActual - shift));
93
}
94
}
95
96
//
以下省略
97
};
sqrt可以用SSE指令來實現,Atan則不行,只能用查表,但是查表函數依然很複雜,所以也必須要簡化。sqrt有另一個選擇是用Wild Magic Library裡的FastInvSqrt(x)函數:
//----------------------------------------------------------------------------
// Fast approximation of 1 / sqrt(fValue).
//
// The input is narrowed to float, its bit pattern is used to form an initial
// guess (magic constant 0x5f3759df minus half the bits), and one
// Newton-Raphson step refines that guess. Only meaningful for positive
// inputs. Assumes sizeof(int) == sizeof(float).
template <class Real>
Real Math<Real>::FastInvSqrt (Real fValue)
{
    // TO DO.  This routine was designed for 'float'.  Come up with an
    // equivalent one for 'double' and specialize the templates for 'float'
    // and 'double'.
    float fFValue = (float)fValue;
    const float fHalf = 0.5f*fFValue;

    // Reinterpret the float's bits as an int. memcpy is used instead of a
    // pointer cast because dereferencing an int* that points at a float
    // violates the aliasing rules; compilers reduce this to a register move.
    int i;
    memcpy(&i, &fFValue, sizeof(i));
    i = 0x5f3759df - (i >> 1);
    memcpy(&fFValue, &i, sizeof(fFValue));

    // One Newton-Raphson iteration: y <- y * (1.5 - 0.5 * x * y * y).
    fFValue = fFValue*(1.5f - fHalf*fFValue*fFValue);
    return (Real)fFValue;
}

這裡面用到了float格式當int用的高級技巧,所以我看不懂 :-( 不過試過用 1.0f / FastInvSqrt(x) 來代替sqrt(x),可以略微快一點,而且這裡面的所有操作都可以用SSE實現,所以也是可以試一下的,但是這裡沒有用這個也達到了3倍的速度提升,後來就懶了一下,沒有使用,直接用SSE的四操作數sqrt操作:
__m128 _mm_sqrt_ps(__m128 a );
SQRTPS

另一個問題是ATan查表函數裡的分支和浮點乘除法,考慮把這些全部移出到外面,放在主循環裡做,算出用int表達的x,y所在的象限,以及相應的查表索引號,再傳給查表函數算,最後查表函數簡化成下面這樣:
// Branch-free arctangent lookup with a precomputed quadrant and index.
//
// y, x : sign flags of the gradient components. The SSE loop that calls this
//        passes _mm_cmpgt_ps lane masks, i.e. -1 when the component is > 0
//        and 0 otherwise.
// idx  : precomputed position inside the selected table.
//
// The flags select one of four tables via 2*x + y + 3:
//   (x, y) = (-1, -1) -> 0,  (-1, 0) -> 1,  (0, -1) -> 2,  (0, 0) -> 3.
// Neither the table number nor idx is range-checked here.
static __forceinline float ValueDirect(int y, int x, int idx)
{
    const int quadrantTable = x * 2 + y + 3;
    return m_dATAN_LU[quadrantTable][idx];
}

x和y代表原來的浮點數x,y的正負,原來的代碼只留一個象限的表是節省空間的一個trick,這裡我們為了節省加減 LU_PI 的操作,重新還原為4個表格。這裡的x,y,idx全部在SSE裡算好,至於整數加法與乘法,因為優化的空間不大,所以沒有在SSE裡做,雖然SSE2下面其實提供了很多的整數操作指令的。
這樣,所有的準備工作就完成了,下面是重新寫的主循環,為了節省指令數,直接寫匯編了,有關指令的細節,可以參考MSDN C++ Language Reference => Compiler Intrinsics。由於沒有直接的求絕對值指令,但是有max指令,這裡用了max(x, -x)的方式來求,浮點數與整數的轉換用SSE2的指令來做:
1
magnitudes.resize(GetCount() - 1, NULL);
2
directions.resize(GetCount() - 1, NULL);
3
4
ImageArrayf imggm;
5
6
int w = imgScaled[0]->Width();
7
int h = imgScaled[0]->Height();
8
9
int scnt = GetCount() - 1;
10
11
ImageArrayf imgsa(w, h), imgsb(w, h);
12
ImageArray<int> imgsi(w, h), imggx(w, h), imggy(w, h);
13
imggm.CreateDataArray(w, h);
14
15
for (int s = 1 ; s < (GetCount() - 1) ; ++s)
{
16
magnitudes[s] = new ImageArrayf(imgScaled[s]->Width(), imgScaled[s]->Height());
17
directions[s] = new ImageArrayf(imgScaled[s]->Width(), imgScaled[s]->Height());
18
}
19
20
__m128 ma, mb, mr;
21
__m128 na, nb, nr;
22
__m128 gl, gr, gtt, gb;
23
__m128 gx, gy, sgx, sgy, sg, sqsg;
24
__m128 gn, gi;
25
__m128i gii;
26
__m128 gzero;
27
28
memset(gzero.m128_f32, 0, sizeof(float) * 4);
29
30
for(int i=0;i<4;i++)
31

{
32
gn.m128_f32[i] = AtanLookupF32::NDOUBLE();
33
}
34
35
for (int s = 1 ; s < scnt ; ++s)
{
36
37
ImageArrayf& imgt = *imgScaled[s];
38
39
imgt.fillShiftedImage(1, imgsa);
40
imgt.fillShiftedImage(2, imgsb);
41
42
for (int y = 1 ; y < (h - 1) ; ++y)
{
43
int x;
44
for (x = 0 ; x < (w - 2) ; x += 4)
{
45
46
gl = _mm_load_ps(imgt[y] + x);
47
gr = _mm_load_ps(imgsb[y] + x);
48
gtt = _mm_load_ps(imgsa[y+1] + x);
49
gb = _mm_load_ps(imgsa[y-1] + x);
50
51
_asm
52
{
53
// x0 = right;
54
movaps xmm0, gr;
55
56
// x1 = left;
57
movaps xmm1, gl;
58
59
// x2 = top;
60
movaps xmm2, gtt;
61
62
// x3 = bottom
63
movaps xmm3, gb;
64
65
// x0 = right - left = gx
66
subps xmm0, xmm1;
67
68
// x2 = top - bottom = gy
69
subps xmm2, xmm3;
70
71
// x4 = right
72
movaps xmm4, gr;
73
74
// x6 = top
75
movaps xmm6, gtt;
76
77
// x1 = left - right = -gx;
78
subps xmm1, xmm4;
79
80
// x3 = bottom - top = -gy;
81
subps xmm3, xmm6;
82
83
// x1 = |gx|
84
maxps xmm1, xmm0;
85
86
// x3 = |gy|
87
maxps xmm3, xmm2;
88
89
// gx = x0
90
movaps gx, xmm0;
91
92
// gy = x2
93
movaps gy, xmm2;
94
95
// x1 = |gx| + |gy|
96
addps xmm1, xmm3;
97
98
// x4 = gx;
99
movaps xmm4, xmm0;
100
101
// x6 = gy;
102
movaps xmm6, xmm2;
103
104
// x4 = gx^2;
105
mulps xmm4, xmm4;
106
107
// x6 = gy^2;
108
mulps xmm6, xmm6;
109
110
// x4 = gx^2 + gy^2;
111
addps xmm4, xmm6;
112
113
// x4 = sqrt(
)
114
sqrtps xmm4, xmm4;
115
116
// sqsg = x4;
117
movaps sqsg, xmm4;
118
119
// x3 = |gy| / (|gx| + |gy|) = dy;
120
divps xmm3, xmm1;
121
122
// x1 = n;
123
movaps xmm1, gn;
124
125
// x3 = |dy| * n;
126
mulps xmm3, xmm1;
127
128
// gi = |dy| * n;
129
movaps gi, xmm3;
130
}
131
132
_mm_store_ps(imggm[y] + x, sqsg);
133
134
gx = _mm_cmpgt_ps(gx, gzero);
135
gy = _mm_cmpgt_ps(gy, gzero);
136
137
_mm_store_si128((__m128i*)(imggx[y] + x), *((__m128i*)&gx));
138
_mm_store_si128((__m128i*)(imggy[y] + x), *((__m128i*)&gy));
139
140
gii = _mm_cvtps_epi32(gi);
141
_mm_store_si128((__m128i*)(imgsi[y] + x), gii);
142
}
143
}
144
145
for (int y = 1 ; y < (h - 1) ; ++y)
{
146
for (int x = 1 ; x < (w - 1) ; x ++)
{
147
magnitudes[s]->At(x, y) = imggm[y][x-1];
148
directions[s]->At(x, y) = AtanLookupF32::ValueDirect(imggy[y][x-1], imggx[y][x-1], imgsi[y][x-1]);
149
}
150
}
151
}
152