Minor fix

This commit is contained in:
AlexeyAB 2019-01-06 15:45:10 +03:00
parent 333f1de2c3
commit c75fbb5f2e
4 changed files with 35 additions and 43 deletions

View File

@ -888,6 +888,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
if(l.c % 32 == 0) if(l.c % 32 == 0)
{ {
//printf(" l.index = %d - new XNOR \n", l.index);
int ldb_align = l.lda_align; int ldb_align = l.lda_align;
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8; size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
size_t t_intput_size = new_ldb * l.bit_align;// n; size_t t_intput_size = new_ldb * l.bit_align;// n;
@ -906,7 +908,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
free(re_packed_input); free(re_packed_input);
// convolution the packed inputs and weights: float x 32 by channel (as in cuDNN) // slow - convolution the packed inputs and weights: float x 32 by channel (as in cuDNN)
//convolution_repacked((uint32_t *)bin_re_packed_input, (uint32_t *)l.align_bit_weights, l.output, //convolution_repacked((uint32_t *)bin_re_packed_input, (uint32_t *)l.align_bit_weights, l.output,
// l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr); // l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr);
@ -920,10 +922,11 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
int new_k = l.size*l.size*l.c / 32; int new_k = l.size*l.size*l.c / 32;
// gemm_nn_bin_32bit_packed(m, n, new_k, 1, // good for (l.c == 64)
// l.align_bit_weights, l.new_lda/32, //gemm_nn_bin_32bit_packed(m, n, new_k, 1,
// b, n, // l.align_bit_weights, l.new_lda/32,
// c, n, l.mean_arr); // b, n,
// c, n, l.mean_arr);
// // then exit from if() // // then exit from if()
@ -951,6 +954,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
else { // else (l.c % 32 != 0) else { // else (l.c % 32 != 0)
//-------------------------------------------------------- //--------------------------------------------------------
//printf(" l.index = %d - old XNOR \n", l.index);
//im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align); //im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align); im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
@ -993,6 +997,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
} }
else { else {
//printf(" l.index = %d - FP32 \n", l.index);
im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b); im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n); gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n);

View File

@ -489,9 +489,9 @@ void transpose_bin(uint32_t *A, uint32_t *B, const int n, const int m,
} }
static inline int popcnt_32(uint32_t val32) { static inline int popcnt_32(uint32_t val32) {
#ifdef WIN32 // Windows #ifdef WIN32 // Windows MSVS
int tmp_count = __popcnt(val32); int tmp_count = __popcnt(val32);
#else // Linux #else // Linux GCC
int tmp_count = __builtin_popcount(val32); int tmp_count = __builtin_popcount(val32);
#endif #endif
return tmp_count; return tmp_count;
@ -755,39 +755,15 @@ void gemm_nn_bin_32bit_packed(int M, int N, int K, float ALPHA,
__m256i all_1 = _mm256_set1_epi8(255); __m256i all_1 = _mm256_set1_epi8(255);
__m256i xnor256 = _mm256_andnot_si256(xor256, all_1); // xnor = not(xor(a,b)) __m256i xnor256 = _mm256_andnot_si256(xor256, all_1); // xnor = not(xor(a,b))
//_m256 count = _mm256_set_ps(
/*
__m256i count = _mm256_setr_epi32(
(int)popcnt_32(xnor256.m256i_u32[0]),
(int)popcnt_32(xnor256.m256i_u32[1]),
(int)popcnt_32(xnor256.m256i_u32[2]),
(int)popcnt_32(xnor256.m256i_u32[3]),
(int)popcnt_32(xnor256.m256i_u32[4]),
(int)popcnt_32(xnor256.m256i_u32[5]),
(int)popcnt_32(xnor256.m256i_u32[6]),
(int)popcnt_32(xnor256.m256i_u32[7]));
__m256i val2 = _mm256_set1_epi32(2);
count = _mm256_mullo_epi32(count, val2);
__m256i val32 = _mm256_set1_epi32(32);
count = _mm256_sub_epi32(count, val32);
int z;
for (z = 0; z < 8; ++z) {
C[i*ldc + j + z] += count.m256i_i32[z] * mean_val;
}
*/
__m256 count = _mm256_setr_ps( __m256 count = _mm256_setr_ps(
popcnt_32(xnor256.m256i_u32[0]), popcnt_32(_mm256_extract_epi32(xnor256, 0)),
popcnt_32(xnor256.m256i_u32[1]), popcnt_32(_mm256_extract_epi32(xnor256, 1)),
popcnt_32(xnor256.m256i_u32[2]), popcnt_32(_mm256_extract_epi32(xnor256, 2)),
popcnt_32(xnor256.m256i_u32[3]), popcnt_32(_mm256_extract_epi32(xnor256, 3)),
popcnt_32(xnor256.m256i_u32[4]), popcnt_32(_mm256_extract_epi32(xnor256, 4)),
popcnt_32(xnor256.m256i_u32[5]), popcnt_32(_mm256_extract_epi32(xnor256, 5)),
popcnt_32(xnor256.m256i_u32[6]), popcnt_32(_mm256_extract_epi32(xnor256, 6)),
popcnt_32(xnor256.m256i_u32[7])); popcnt_32(_mm256_extract_epi32(xnor256, 7)));
__m256 val2 = _mm256_set1_ps(2); __m256 val2 = _mm256_set1_ps(2);
count = _mm256_mul_ps(count, val2); // count * 2 count = _mm256_mul_ps(count, val2); // count * 2
@ -2274,17 +2250,19 @@ void gemm_nn_bin_transposed_32bit_packed(int M, int N, int K, float ALPHA,
for (i = 0; i < M; ++i) { // l.n for (i = 0; i < M; ++i) { // l.n
int j, s; int j, s;
float mean_val = mean_arr[i]; float mean_val = mean_arr[i];
for (j = 0; j < N; ++j) // out_h*out_w;
{
float val = 0;
for (s = 0; s < K; ++s) // l.size*l.size*l.c/32 or (l.size*l.size*l.c) for (s = 0; s < K; ++s) // l.size*l.size*l.c/32 or (l.size*l.size*l.c)
{ {
register uint32_t A_PART = ((uint32_t*)A)[i*lda + s]; register uint32_t A_PART = ((uint32_t*)A)[i*lda + s];
for (j = 0; j < N; ++j) // out_h*out_w;
{
register uint32_t B_PART = ((uint32_t*)B)[j*ldb + s]; register uint32_t B_PART = ((uint32_t*)B)[j*ldb + s];
uint32_t xnor_result = ~(A_PART ^ B_PART); uint32_t xnor_result = ~(A_PART ^ B_PART);
int32_t count = popcnt_32(xnor_result); // must be Signed int int32_t count = popcnt_32(xnor_result); // must be Signed int
C[i*ldc + j] += (2 * count - 32) * mean_val; val += (2 * count - 32) * mean_val;
} }
C[i*ldc + j] += val;
} }
} }
} }
@ -2422,6 +2400,8 @@ void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
} }
} }
is_avx(); // initialize static variable
is_fma_avx2();
int t; int t;
#pragma omp parallel for #pragma omp parallel for
for (t = 0; t < M; ++t) { for (t = 0; t < M; ++t) {

View File

@ -22,6 +22,9 @@ static inline unsigned char get_bit(unsigned char const*const src, size_t index)
return val; return val;
} }
int is_avx();
int is_fma_avx2();
void float_to_bit(float *src, unsigned char *dst, size_t size); void float_to_bit(float *src, unsigned char *dst, size_t size);
void transpose_block_SSE4x4(float *A, float *B, const int n, const int m, void transpose_block_SSE4x4(float *A, float *B, const int n, const int m,

View File

@ -190,6 +190,8 @@ network make_network(int n)
return net; return net;
} }
double get_time_point();
void forward_network(network net, network_state state) void forward_network(network net, network_state state)
{ {
state.workspace = net.workspace; state.workspace = net.workspace;
@ -200,7 +202,9 @@ void forward_network(network net, network_state state)
if(l.delta){ if(l.delta){
scal_cpu(l.outputs * l.batch, 0, l.delta, 1); scal_cpu(l.outputs * l.batch, 0, l.delta, 1);
} }
//double time = get_time_point();
l.forward(l, state); l.forward(l, state);
//printf("%d - Predicted in %lf milli-seconds.\n", i, ((double)get_time_point() - time) / 1000);
state.input = l.output; state.input = l.output;
} }
} }