目次

3. プログラムの高速化

3.1 SIMDによるベクトル化

OpenMOMの計算時間の大部分は式(2-6-5)の第6式の和が占めます。
これは複素数ベクトルの内積ですので、SIMDを用いると計算時間を短縮することができます。
リスト3-1-1にソースコードを示します(x=a・b)。 これは[6]のコードを複素数に拡張したものです。
複素数は通常は構造体(AOS=array of structure)で扱いますが、 ここでは効率よくメモリーにアクセスするために配列(SOA=structure of array)に分解しています。 SOAはコードの可読性が少し下がりますが、高速化のために必要になることがあります。 下記の関数の引数はベクトルaとbの実部と虚部の配列となっており、 これらは16バイト(SSE)または32バイト(AVX)でアラインされている必要があります。

リスト3-1-1 SIMDを用いた複素数ベクトルの内積の計算(cdot.c)


     1	#include <immintrin.h>
     2	
     3	void cdot(
     4		int simd, int n,
     5		const float *a_r, const float *a_i,
     6		const float *b_r, const float *b_i,
     7		float *x_r, float *x_i)
     8	{
     9		float sum_r = 0;
    10		float sum_i = 0;
    11	
    12		if      (simd == 0) {
    13			// no SIMD
    14			for (int i = 0; i < n; i++) {
    15				sum_r += (a_r[i] * b_r[i]) - (a_i[i] * b_i[i]);
    16				sum_i += (a_r[i] * b_i[i]) + (a_i[i] * b_r[i]);
    17			}
    18		}
    19		else if (simd == 1) {
    20			// SSE
    21	#ifdef _WIN32
    22			__declspec(align(16)) float sumr[4], sumi[4];
    23	#else
    24			__attribute__((aligned(16))) float sumr[4], sumi[4];
    25	#endif
    26			__m128 sr, si, v1r, v1i, v2r, v2i;
    27			sr = _mm_setzero_ps();
    28			si = _mm_setzero_ps();
    29			for (int i = 0; i < n; i += 4) {
    30				v1r = _mm_load_ps(a_r + i);
    31				v1i = _mm_load_ps(a_i + i);
    32				v2r = _mm_load_ps(b_r + i);
    33				v2i = _mm_load_ps(b_i + i);
    34				sr = _mm_add_ps(sr, _mm_sub_ps(_mm_mul_ps(v1r, v2r), _mm_mul_ps(v1i, v2i)));
    35				si = _mm_add_ps(si, _mm_add_ps(_mm_mul_ps(v1r, v2i), _mm_mul_ps(v1i, v2r)));
    36			}
    37			_mm_store_ps(sumr, sr);
    38			_mm_store_ps(sumi, si);
    39			sum_r = sumr[0] + sumr[1] + sumr[2] + sumr[3];
    40			sum_i = sumi[0] + sumi[1] + sumi[2] + sumi[3];
    41		}
    42		else if (simd == 2) {
    43			// AVX
    44	#ifdef _WIN32
    45			__declspec(align(32)) float sumr[8], sumi[8];
    46	#else
    47			__attribute__((aligned(32))) float sumr[8], sumi[8];
    48	#endif
    49			__m256 sr, si, v1r, v1i, v2r, v2i;
    50			sr = _mm256_setzero_ps();
    51			si = _mm256_setzero_ps();
    52			for (int i = 0; i < n; i += 8) {
    53				v1r = _mm256_load_ps(a_r + i);
    54				v1i = _mm256_load_ps(a_i + i);
    55				v2r = _mm256_load_ps(b_r + i);
    56				v2i = _mm256_load_ps(b_i + i);
    57				sr = _mm256_add_ps(sr, _mm256_sub_ps(_mm256_mul_ps(v1r, v2r), _mm256_mul_ps(v1i, v2i)));
    58				si = _mm256_add_ps(si, _mm256_add_ps(_mm256_mul_ps(v1r, v2i), _mm256_mul_ps(v1i, v2r)));
    59			}
    60			_mm256_store_ps(sumr, sr);
    61			_mm256_store_ps(sumi, si);
    62			sum_r = sumr[0] + sumr[1] + sumr[2] + sumr[3]
    63			      + sumr[4] + sumr[5] + sumr[6] + sumr[7];
    64			sum_i = sumi[0] + sumi[1] + sumi[2] + sumi[3]
    65			      + sumi[4] + sumi[5] + sumi[6] + sumi[7];
    66		}
    67	
    68		*x_r = sum_r;
    69		*x_i = sum_i;
    70	}