リスト3-1-1 SIMDを用いた複素数ベクトルの内積の計算(cdot.c)
#include <immintrin.h>

/*
 * On GCC/Clang, AVX intrinsics may only be used inside functions compiled
 * for AVX. Marking only the AVX kernel with target("avx") lets this file
 * build without -mavx (keeping the scalar and SSE paths usable on any
 * x86-64 CPU) while still allowing the AVX kernel to be selected at run
 * time. MSVC does not need a per-function target to use AVX intrinsics.
 */
#if defined(__GNUC__) || defined(__clang__)
#define CDOT_TARGET_AVX __attribute__((target("avx")))
#else
#define CDOT_TARGET_AVX
#endif

/*
 * Scalar kernel: adds sum_{i=begin}^{end-1} a[i] * b[i] (complex product,
 * no conjugation) to *sum_r / *sum_i. Does nothing when begin >= end.
 * Used for simd == 0 and for the tail elements that the vector kernels
 * cannot cover with full-width loads.
 */
static void cdot_accumulate(int begin, int end,
                            const float *a_r, const float *a_i,
                            const float *b_r, const float *b_i,
                            float *sum_r, float *sum_i)
{
    float acc_r = *sum_r;
    float acc_i = *sum_i;

    for (int i = begin; i < end; i++) {
        acc_r += (a_r[i] * b_r[i]) - (a_i[i] * b_i[i]);
        acc_i += (a_r[i] * b_i[i]) + (a_i[i] * b_r[i]);
    }
    *sum_r = acc_r;
    *sum_i = acc_i;
}

/*
 * SSE kernel: processes 4 elements per step, then handles the remaining
 * n % 4 elements with the scalar kernel. Overwrites *sum_r / *sum_i.
 */
static void cdot_sse(int n,
                     const float *a_r, const float *a_i,
                     const float *b_r, const float *b_i,
                     float *sum_r, float *sum_i)
{
    /* n rounded toward zero to a multiple of 4; no overflow for any int n. */
    const int n_vec = n - n % 4;
    __m128 acc_r = _mm_setzero_ps();
    __m128 acc_i = _mm_setzero_ps();

    for (int i = 0; i < n_vec; i += 4) {
        /* Unaligned loads: callers need not supply 16-byte-aligned arrays. */
        const __m128 v1r = _mm_loadu_ps(a_r + i);
        const __m128 v1i = _mm_loadu_ps(a_i + i);
        const __m128 v2r = _mm_loadu_ps(b_r + i);
        const __m128 v2i = _mm_loadu_ps(b_i + i);
        acc_r = _mm_add_ps(acc_r, _mm_sub_ps(_mm_mul_ps(v1r, v2r), _mm_mul_ps(v1i, v2i)));
        acc_i = _mm_add_ps(acc_i, _mm_add_ps(_mm_mul_ps(v1r, v2i), _mm_mul_ps(v1i, v2r)));
    }

    /* Horizontal sum of the 4 lanes (unaligned stores: no alignment attributes needed). */
    float lane_r[4], lane_i[4];
    _mm_storeu_ps(lane_r, acc_r);
    _mm_storeu_ps(lane_i, acc_i);
    *sum_r = lane_r[0] + lane_r[1] + lane_r[2] + lane_r[3];
    *sum_i = lane_i[0] + lane_i[1] + lane_i[2] + lane_i[3];

    /* Elements n_vec .. n-1 that do not fill a whole vector. */
    cdot_accumulate(n_vec, n, a_r, a_i, b_r, b_i, sum_r, sum_i);
}

/*
 * AVX kernel: processes 8 elements per step, then handles the remaining
 * n % 8 elements with the scalar kernel. Overwrites *sum_r / *sum_i.
 * Must only be called on CPUs that support AVX.
 */
CDOT_TARGET_AVX
static void cdot_avx(int n,
                     const float *a_r, const float *a_i,
                     const float *b_r, const float *b_i,
                     float *sum_r, float *sum_i)
{
    /* n rounded toward zero to a multiple of 8; no overflow for any int n. */
    const int n_vec = n - n % 8;
    __m256 acc_r = _mm256_setzero_ps();
    __m256 acc_i = _mm256_setzero_ps();

    for (int i = 0; i < n_vec; i += 8) {
        /* Unaligned loads: callers need not supply 32-byte-aligned arrays. */
        const __m256 v1r = _mm256_loadu_ps(a_r + i);
        const __m256 v1i = _mm256_loadu_ps(a_i + i);
        const __m256 v2r = _mm256_loadu_ps(b_r + i);
        const __m256 v2i = _mm256_loadu_ps(b_i + i);
        acc_r = _mm256_add_ps(acc_r, _mm256_sub_ps(_mm256_mul_ps(v1r, v2r), _mm256_mul_ps(v1i, v2i)));
        acc_i = _mm256_add_ps(acc_i, _mm256_add_ps(_mm256_mul_ps(v1r, v2i), _mm256_mul_ps(v1i, v2r)));
    }

    /* Horizontal sum of the 8 lanes. */
    float lane_r[8], lane_i[8];
    _mm256_storeu_ps(lane_r, acc_r);
    _mm256_storeu_ps(lane_i, acc_i);
    *sum_r = lane_r[0] + lane_r[1] + lane_r[2] + lane_r[3]
           + lane_r[4] + lane_r[5] + lane_r[6] + lane_r[7];
    *sum_i = lane_i[0] + lane_i[1] + lane_i[2] + lane_i[3]
           + lane_i[4] + lane_i[5] + lane_i[6] + lane_i[7];

    /* Elements n_vec .. n-1 that do not fill a whole vector. */
    cdot_accumulate(n_vec, n, a_r, a_i, b_r, b_i, sum_r, sum_i);
}

/*
 * cdot: sum of element-wise complex products of two length-n vectors
 * stored as separate real/imaginary arrays:
 *
 *   (*x_r) + i*(*x_i) = sum_{k=0}^{n-1} (a_r[k] + i*a_i[k]) * (b_r[k] + i*b_i[k])
 *
 * No conjugation is applied to either operand.
 *
 * simd selects the kernel: 0 = scalar, 1 = SSE, 2 = AVX (the caller must
 * ensure the CPU supports AVX). Any other value stores 0 to both outputs.
 * Any n is accepted (n <= 0 yields 0), and the input arrays need no
 * particular alignment. The vector kernels add in a different order than
 * the scalar one, so results may differ in the last bits.
 */
void cdot(
    int simd, int n,
    const float *a_r, const float *a_i,
    const float *b_r, const float *b_i,
    float *x_r, float *x_i)
{
    float sum_r = 0.0f;
    float sum_i = 0.0f;

    switch (simd) {
    case 0: /* no SIMD */
        cdot_accumulate(0, n, a_r, a_i, b_r, b_i, &sum_r, &sum_i);
        break;
    case 1: /* SSE */
        cdot_sse(n, a_r, a_i, b_r, b_i, &sum_r, &sum_i);
        break;
    case 2: /* AVX */
        cdot_avx(n, a_r, a_i, b_r, b_i, &sum_r, &sum_i);
        break;
    default:
        /* Unrecognised mode: the result is left at 0. */
        break;
    }

    *x_r = sum_r;
    *x_i = sum_i;
}