From 2a9f7e44ce1b73d3d56ef83f83e94f074ecac3f9 Mon Sep 17 00:00:00 2001
From: AlexeyAB
Date: Fri, 6 Apr 2018 15:55:00 +0300
Subject: [PATCH] Added automatic AVX support - speedup +20% on CPU x86_64
 Intel Skylake

---
 Makefile   |   2 +-
 src/gemm.c | 152 ++++++++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 140 insertions(+), 14 deletions(-)

diff --git a/Makefile b/Makefile
index 2c5fdb66..986759fe 100644
--- a/Makefile
+++ b/Makefile
@@ -44,7 +44,7 @@ NVCC=nvcc
 OPTS=-Ofast
 LDFLAGS= -lm -pthread
 COMMON=
-CFLAGS=-Wall -Wfatal-errors
+CFLAGS=-Wall -Wfatal-errors -ffp-contract=fast -mavx
 
 ifeq ($(DEBUG), 1)
 OPTS=-O0 -g
diff --git a/src/gemm.c b/src/gemm.c
index c3154ec9..4eb36fb3 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -71,21 +71,147 @@ void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
     gemm_cpu( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
 }
 
-void gemm_nn(int M, int N, int K, float ALPHA,
-        float *A, int lda,
-        float *B, int ldb,
-        float *C, int ldc)
+#if (defined(__AVX__) && defined(__x86_64__)) || defined(_WIN64)
+
+#define OSXSAVEFlag (1UL<<27)
+#define AVXFlag     ((1UL<<28)|OSXSAVEFlag)
+#define FMAFlag     ((1UL<<12)|AVXFlag|OSXSAVEFlag)
+#define CLMULFlag   ((1UL<< 1)|AVXFlag|OSXSAVEFlag)
+#define VAESFlag    ((1UL<<25)|AVXFlag|OSXSAVEFlag)
+
+#include <stdint.h>
+
+#ifdef _WIN64
+#include <intrin.h>
+#include <ammintrin.h>
+#include <immintrin.h>
+#include <smmintrin.h>
+
+#else   // Linux GCC/Clang
+#include <x86intrin.h>
+#include <ammintrin.h>
+#include <immintrin.h>
+#include <smmintrin.h>
+
+// Raw CPUID for leaf `eax`; results stored as abcd[0..3] = EAX,EBX,ECX,EDX.
+void asm_cpuid(uint32_t* abcd, uint32_t eax)
 {
-    int i,j,k;
-    for(i = 0; i < M; ++i){
-        for(k = 0; k < K; ++k){
-            register float A_PART = ALPHA*A[i*lda+k];
-            for(j = 0; j < N; ++j){
-                C[i*ldc+j] += A_PART*B[k*ldb+j];
-            }
-        }
-    }
+    uint32_t ebx = 0, edx = 0, ecx = 0;
+
+    // EBX is saved to EDI and later restored (EBX may be the PIC register)
+    __asm__("movl %%ebx, %%edi;"
+        "cpuid;"
+        "xchgl %%ebx, %%edi;"
+        : "=D"(ebx),
+        "+a"(eax), "+c"(ecx), "=d"(edx));
+
+    abcd[0] = eax;
+    abcd[1] = ebx;
+    abcd[2] = ecx;
+    abcd[3] = edx;
 }
+#endif
+
+// Returns 1 iff every CPUID.1:ECX bit in idFeature is set on this CPU.
+// static inline: plain C99 `inline` emits no external definition -> link
+// errors when the compiler declines to inline (e.g. -O0).
+static inline int simd_detect_x86(unsigned int idFeature)
+{
+    uint32_t regs[4];    // EAX, EBX, ECX, EDX;
+#ifdef _WIN32
+    __cpuid(regs, 0);
+    if (regs[0] > 1U) __cpuid(regs, 1);
+#else
+    asm_cpuid(regs, 0);
+    if (regs[0] > 1U) asm_cpuid(regs, 1);   // leaf 1: feature flags in ECX (leaf 0 holds vendor string)
+#endif
+
+    if ((regs[2] & idFeature) != idFeature)
+        return 0;
+    return 1;
+}
+
+// Cached runtime check for AVX+OSXSAVE support; prints the decision once.
+static inline int is_fma_avx() {
+    static int result = -1;
+    if (result == -1) {
+        result = simd_detect_x86(AVXFlag);
+        if (result == 1) printf(" Used AVX \n");
+        else printf(" Not used AVX \n");
+    }
+    return result;
+}
+
+// https://software.intel.com/sites/landingpage/IntrinsicsGuide
+void gemm_nn(int M, int N, int K, float ALPHA,
+    float *A, int lda,
+    float *B, int ldb,
+    float *C, int ldc)
+{
+    int i, j, k;
+    if (is_fma_avx() == 1) {    // AVX
+        for (i = 0; i < M; ++i) {
+            for (k = 0; k < K; ++k) {
+                float A_PART = ALPHA*A[i*lda + k];
+                __m256 a256, b256, c256, result256;    // AVX
+                a256 = _mm256_set1_ps(A_PART);
+                for (j = 0; j < N - 8; j += 8) {
+                    b256 = _mm256_loadu_ps(&B[k*ldb + j]);
+                    c256 = _mm256_loadu_ps(&C[i*ldc + j]);
+                    // FMA - Intel Haswell (2013), AMD Piledriver (2012)
+                    //result256 = _mm256_fmadd_ps(a256, b256, c256);
+                    result256 = _mm256_mul_ps(a256, b256);
+                    result256 = _mm256_add_ps(result256, c256);
+                    _mm256_storeu_ps(&C[i*ldc + j], result256);
+                }
+
+                int prev_end = (N % 8 == 0) ? (N - 8) : (N / 8) * 8;
+                for (j = prev_end; j < N; ++j)
+                    C[i*ldc + j] += A_PART*B[k*ldb + j];
+            }
+        }
+    }
+    else {
+        for (i = 0; i < M; ++i) {
+            for (k = 0; k < K; ++k) {
+                register float A_PART = ALPHA*A[i*lda + k];
+                for (j = 0; j < N; ++j) {
+                    C[i*ldc + j] += A_PART*B[k*ldb + j];
+                }
+                /* // SSE
+                __m128 a128, b128, c128, result128;    // SSE
+                a128 = _mm_set1_ps(A_PART);
+                for (j = 0; j < N - 4; j += 4) {
+                    b128 = _mm_loadu_ps(&B[k*ldb + j]);
+                    c128 = _mm_loadu_ps(&C[i*ldc + j]);
+                    //result128 = _mm_fmadd_ps(a128, b128, c128);
+                    result128 = _mm_mul_ps(a128, b128);
+                    result128 = _mm_add_ps(result128, c128);
+                    _mm_storeu_ps(&C[i*ldc + j], result128);
+                }
+
+                int prev_end = (N % 4 == 0) ? (N - 4) : (N / 4) * 4;
+                for (j = prev_end; j < N; ++j){
+                    C[i*ldc + j] += A_PART*B[k*ldb + j];
+                }
+                */
+            }
+        }
+    }
+}
+#else
+
+void gemm_nn(int M, int N, int K, float ALPHA,
+    float *A, int lda,
+    float *B, int ldb,
+    float *C, int ldc)
+{
+    int i, j, k;
+    for (i = 0; i < M; ++i) {
+        for (k = 0; k < K; ++k) {
+            register float A_PART = ALPHA*A[i*lda + k];
+            for (j = 0; j < N; ++j) {
+                C[i*ldc + j] += A_PART*B[k*ldb + j];
+            }
+        }
+    }
+}
+#endif    // __x86_64
 
 void gemm_nt(int M, int N, int K, float ALPHA,
         float *A, int lda,