Fix layer_normalize gradients (#3001)

* Fix layer_normalize gradients

* fix layer_norm CPU

* attempt to fix the cuda version

* fix gamma_grad and beta_grad

* update cuda test

* use a block of size 1 to avoid race conditions

* improve the speed of CUDA path of layer_norm

* improve the speed of CUDA path of layer_norm
This commit is contained in:
Adrià 2024-09-01 22:05:09 +09:00 committed by GitHub
parent 27a0135220
commit 253098eb1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 213 additions and 149 deletions

View File

@ -1270,22 +1270,19 @@ namespace dlib
const tensor& beta
)
{
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(
have_same_dimensions(gamma, beta) &&
src.k() == gamma.k() &&
src.nr() == gamma.nr() &&
src.nc() == gamma.nc() &&
gamma.k() == src.k() &&
gamma.nr() == 1 &&
gamma.nc() == 1 &&
eps > 0,
"\nsrc.k(): " << src.k() <<
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\neps: " << eps
);
@ -1296,46 +1293,53 @@ namespace dlib
// first compute means and invstds
means = 0;
invstds = 0;
const auto p_invstds = invstds.host();
const auto p_means = means.host();
auto p_src = src.host();
const float* p_src = src.host();
float* p_invstds = invstds.host();
float* p_means = means.host();
const long num = src.nr() * src.nc();
// compute means, and sum of squares
for (long n = 0; n < src.num_samples(); ++n)
{
for (long k = 0; k < src.k(); ++k)
{
for (long i = 0; i < num; ++i)
{
float val = p_src[n*num+i];
p_means[n] += val;
p_invstds[n] += val*val;
p_means[n] += *p_src;
p_invstds[n] += (*p_src) * (*p_src);
++p_src;
}
}
means /= num;
invstds /= num;
}
means /= src.k() * num;
invstds /= src.k () * num;
// copy data back to host
invstds.host(); means.host();
invstds.host();
means.host();
// compute variances
for (long n = 0; n < src.num_samples(); ++n)
{
auto var = p_invstds[n] - p_means[n] * p_means[n];
p_invstds[n] = 1.0f / std::sqrt(var + eps);
p_invstds[n] = 1.0f / std::sqrt(p_invstds[n] - p_means[n] * p_means[n] + eps);
}
p_src = src.host();
auto p_dest = dest.host();
auto p_gamma = gamma.host();
auto p_beta = beta.host();
float* p_dest = dest.host();
const float* p_gamma = gamma.host();
const float* p_beta = beta.host();
for (long n = 0; n < src.num_samples(); ++n)
{
for (long k = 0; k < src.k(); ++k)
{
for (long i = 0; i < num; ++i)
{
*p_dest = (*p_src - p_means[n])*p_invstds[n];
*p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
*p_dest = (*p_src - p_means[n]) * p_invstds[n];
*p_dest = (*p_dest) * p_gamma[k] + p_beta[k];
++p_src;
++p_dest;
}
}
}
}
void layer_normalize_gradient (
const double eps,
@ -1346,22 +1350,26 @@ namespace dlib
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
tensor& beta_grad,
resizable_tensor& dmeans,
resizable_tensor& dvars
)
{
const long num = src.k() * src.nr() * src.nc();
const long num = src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.k() == gamma.k());
DLIB_CASSERT(src.nr() == gamma_grad.nr());
DLIB_CASSERT(src.nc() == beta_grad.nc());
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(gamma.k() == src.k());
DLIB_CASSERT(gamma.nr() == 1);
DLIB_CASSERT(gamma.nc() == 1);
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(eps > 0);
beta_grad = 0;
gamma_grad = 0;
auto p_grad = gradient_input.host();
auto p_src = src.host();
const auto p_gamma = gamma.host();
@ -1370,7 +1378,6 @@ namespace dlib
const auto p_invstds = invstds.host();
const auto p_means = means.host();
resizable_tensor dvars, dmeans;
dvars.copy_size(invstds);
dmeans.copy_size(means);
dvars = 0;
@ -1379,50 +1386,58 @@ namespace dlib
const auto p_dmeans = dmeans.host();
for (long n = 0; n < src.num_samples(); ++n)
{
const float invstd_pow = -0.5 * std::pow(p_invstds[n], 3.0f);
for (long k = 0; k < src.k(); ++k)
{
for (long i = 0; i < num; ++i)
{
const float x_hat = (*p_src - p_means[n])*p_invstds[n];
p_beta_grad[i] += *p_grad;
p_gamma_grad[i] += (*p_grad)*x_hat;
const float x_hat = (*p_src - p_means[n]) * p_invstds[n];
p_beta_grad[k] += *p_grad;
p_gamma_grad[k] += (*p_grad) * x_hat;
const float dx = *p_grad * p_gamma[n];
const float dx = *p_grad * p_gamma[k];
p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*p_invstds[n]*p_invstds[n]*p_invstds[n];
p_dvars[n] += dx * (*p_src - p_means[n]) * invstd_pow;
++p_grad;
++p_src;
}
}
}
const float invnum = 1.0f/num;
p_grad = gradient_input.host();
p_src = src.host();
const float invnum = 1.0f / (src.k() * num);
for (long n = 0; n < src.num_samples(); ++n)
{
for (long k = 0; k < src.k(); ++k)
{
for (long i = 0; i < num; ++i)
{
const float dx = *p_grad * p_gamma[i];
const float dx = *p_grad * p_gamma[k];
p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum;
p_dmeans[n] += -dx * p_invstds[n] + p_dvars[n] * -2 * (*p_src - p_means[n]) * invnum;
++p_grad;
++p_src;
}
}
}
p_grad = gradient_input.host();
p_src = src.host();
auto p_src_grad = src_grad.host();
for (long n = 0; n < src.num_samples(); ++n)
{
for (long k = 0; k < src.k(); ++k)
{
for (long i = 0; i < num; ++i)
{
const float dx = *p_grad * p_gamma[i];
*p_src_grad += dx*p_invstds[n] +
p_dvars[n] *2*(*p_src - p_means[n])*invnum +
p_dmeans[n]*invnum;
const float dx = *p_grad * p_gamma[k];
*p_src_grad += dx * p_invstds[n] +
p_dvars[n] * 2 * (*p_src - p_means[n]) * invnum +
p_dmeans[n] * invnum;
++p_grad;
++p_src;
@ -1430,6 +1445,7 @@ namespace dlib
}
}
}
}
// -----------------------------------------------------------------------------------

View File

@ -250,7 +250,9 @@ namespace dlib
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
tensor& beta_grad,
resizable_tensor& dmeans,
resizable_tensor& dvars
);
// -----------------------------------------------------------------------------------

View File

@ -2085,21 +2085,32 @@ namespace dlib
// ----------------------------------------------------------------------------------------
__global__ void _cuda_layer_normalize(float* out, const float* s, float* m, float* v, const float* g, const float* b, float eps, size_t ns, size_t num)
__global__ void _cuda_layer_normalize(
float* out,
const float* s,
float* m,
float* v,
const float* g,
const float* b,
float eps,
size_t ns,
size_t k,
size_t num
)
{
// compute means and sum of squares
for (auto n : grid_stride_range_y(0, ns))
{
auto p = s + n * num;
const auto ps = s + n * k * num;
float means = 0;
float invstds = 0;
for (auto i : grid_stride_range(0, num))
for (auto i : grid_stride_range(0, k * num))
{
means += p[i];
invstds += p[i] * p[i];
means += ps[i];
invstds += ps[i] * ps[i];
}
warp_reduce_atomic_add(m[n], means/num);
warp_reduce_atomic_add(v[n], invstds/num);
warp_reduce_atomic_add(m[n], means / (k * num));
warp_reduce_atomic_add(v[n], invstds / (k * num));
}
__syncthreads();
@ -2108,61 +2119,19 @@ namespace dlib
{
for (auto i : grid_stride_range(0, 1))
{
auto var = v[n] - m[n] * m[n];
v[n] = 1.0f / std::sqrt(var + eps);
v[n] = 1.0f / std::sqrt(v[n] - m[n] * m[n] + eps);
}
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
for (auto i : grid_stride_range(0, num))
const auto ps = s + n * k * num;
const auto pout = out + n * k * num;
for (auto i : grid_stride_range(0, k * num))
{
const float val = (s[n*num+i]-m[n])*v[n];
out[n*num+i] = val*g[i]+b[i];
}
}
}
__global__ void _cuda_layer_normalize_gradient(float* out, float* gg, float* bg, const float* s, const float* gi, const float* m, const float* v, const float* g, float* dm, float* dv, float eps, size_t ns, size_t num)
{
for (auto n : grid_stride_range_y(0, ns))
{
float temp_dv = 0;
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float x_hat = (s[idx] - m[n])*v[n];
bg[i] += gi[idx];
gg[i] += gi[idx]*x_hat;
const float dx = gi[idx] * g[n];
temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n];
}
warp_reduce_atomic_add(dv[n], temp_dv);
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
float temp_dm = 0;
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float dx = gi[idx]*g[i];
temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num;
}
warp_reduce_atomic_add(dm[n], temp_dm);
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float dx = gi[idx]*g[i];
out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num;
pout[i] = (ps[i] - m[n]) * v[n];
pout[i] = pout[i] * g[i / num] + b[i / num];
}
}
}
@ -2177,22 +2146,20 @@ namespace dlib
const tensor& beta
)
{
const long num = src.k() * src.nr() * src.nc();
const long num = src.nr() * src.nc();
DLIB_CASSERT(
have_same_dimensions(gamma, beta) &&
src.k() == gamma.k() &&
src.nr() == gamma.nr() &&
src.nc() == gamma.nc() &&
gamma.k() == src.k() &&
gamma.nr() == 1 &&
gamma.nc() == 1 &&
eps > 0,
"\nsrc.k(): " << src.k() <<
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\neps: " << eps
);
@ -2201,8 +2168,78 @@ namespace dlib
invstds.set_size(src.num_samples());
means = 0;
invstds = 0;
launch_kernel(_cuda_layer_normalize, max_jobs(num, src.num_samples()), dest.device(), src.device(),
means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), num);
launch_kernel(_cuda_layer_normalize, max_jobs(src.k() * num, src.num_samples()), dest.device(), src.device(),
means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), src.k(), num);
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_layer_normalize_gradient(
float* out,
float* gg,
float* bg,
const float* s,
const float* gi,
const float* m,
const float* v,
const float* g,
float* dm,
float* dv,
float eps,
size_t ns,
size_t ks,
size_t num)
{
for (auto nk : grid_stride_range_y(0, ns * ks))
{
const auto n = nk / ks;
const auto k = nk % ks;
const auto ps = s + (n * ks + k) * num;
const auto pgi = gi + (n * ks + k) * num;
const float invstd_pow = -0.5 * std::pow(v[n], 3.0f);
float temp_bg = 0;
float temp_gg = 0;
float temp_dv = 0;
for (auto i : grid_stride_range(0, num))
{
const float x_hat = (ps[i] - m[n]) * v[n];
const float dx = pgi[i] * g[i / num];
temp_bg += pgi[i];
temp_gg += pgi[i] * x_hat;
temp_dv += dx * (ps[i] - m[n]) * invstd_pow;
}
warp_reduce_atomic_add(bg[k], temp_bg);
warp_reduce_atomic_add(gg[k], temp_gg);
warp_reduce_atomic_add(dv[n], temp_dv);
}
__syncthreads();
const float invnum = 1.0f / (ks * num);
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = s + n * ks * num;
const auto pgi = gi + n * ks * num;
float temp_dm = 0;
for (auto i : grid_stride_range(0, ks * num))
{
const float dx = pgi[i] * g[i / num];
temp_dm += -dx * v[n] + dv[n] * -2 * (ps[i] - m[n]) * invnum;
}
warp_reduce_atomic_add(dm[n], temp_dm);
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = s + n * ks * num;
const auto pgi = gi + n * ks * num;
const auto pout = out + n * ks * num;
for (auto i : grid_stride_range(0, ks * num))
{
const float dx = pgi[i] * g[i / num];
pout[i] += dx * v[n] + dv[n] * 2 * (ps[i] - m[n]) * invnum + dm[n] * invnum;
}
}
}
void layer_normalize_gradient (
@ -2214,32 +2251,33 @@ namespace dlib
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
tensor& beta_grad,
resizable_tensor& dmeans,
resizable_tensor& dvars
)
{
const long num = src.k() * src.nr() * src.nc();
const long num = src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.k() == gamma.k());
DLIB_CASSERT(src.nr() == gamma.nr());
DLIB_CASSERT(src.nc() == gamma.nc());
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(gamma.k() == src.k());
DLIB_CASSERT(gamma.nr() == 1);
DLIB_CASSERT(gamma.nc() == 1);
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, gamma));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(eps > 0);
beta_grad = 0;
gamma_grad = 0;
resizable_tensor dvars, dmeans;
dvars.copy_size(invstds);
dmeans.copy_size(means);
dvars = 0;
dmeans = 0;
launch_kernel(_cuda_layer_normalize_gradient, max_jobs(num, src.num_samples()),
launch_kernel(_cuda_layer_normalize_gradient, max_jobs(src.k() * num, src.num_samples()),
src_grad.device(), gamma_grad.device(), beta_grad.device(), src.device(),
gradient_input.device(), means.device(), invstds.device(), gamma.device(),
dmeans.device(), dvars.device(), eps, src.num_samples(), num);
dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
}
// ----------------------------------------------------------------------------------------

View File

@ -357,7 +357,9 @@ namespace dlib
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
tensor& beta_grad,
resizable_tensor& dmeans,
resizable_tensor& dvars
);
// -----------------------------------------------------------------------------------

View File

@ -684,13 +684,15 @@ namespace dlib { namespace tt
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
tensor& beta_grad,
resizable_tensor& dmeans,
resizable_tensor& dvars
)
{
#ifdef DLIB_USE_CUDA
cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad, dmeans, dvars);
#else
cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad, dmeans, dvars);
#endif
}

View File

@ -814,13 +814,13 @@ namespace dlib { namespace tt
/*!
requires
- eps > 0
- src.num_samples() == gamma.size() == beta.size()
- src.k() == gamma.size() == beta.size()
- gamma.num_samples() == gamma.nr() == gamma.nc() == 1
- have_same_dimensions(gamma, beta) == true
- beta.num_samples() ==beta.nr() ==gamma.nc() == 1
ensures
- have_same_dimensions(#dest, src) == true
- #means.size() == invstds.size() == src.num_samples()
- #dest == the normalized version of src.
- #dest == the normalized version of src, sample-wise.
- #means == the mean values of the contents of src.
- #invstds == 1/(the standard deviation values of the contents of src).
!*/
@ -834,7 +834,9 @@ namespace dlib { namespace tt
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
tensor& beta_grad,
resizable_tensor& dmeans,
resizable_tensor& dvars
);
/*!
requires
@ -847,8 +849,6 @@ namespace dlib { namespace tt
- have_same_dimensions(gamma, beta_grad) == true
- means.size() == src.num_samples()
- invstds.size() == src.num_samples()
- have_same_dimensions(means, gamma) == true
- have_same_dimensions(invstds, gamma) == true
ensures
- Let f(src,gamma,beta) == dot(gradient_input, dest output of
layer_normalize(eps,dest,means,invstds,src,gamma,beta))

View File

@ -1403,7 +1403,7 @@ namespace dlib
template <typename SUBNET>
void setup (const SUBNET& sub)
{
gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc());
gamma = alias_tensor(1, sub.get_output().k());
beta = gamma;
params.set_size(gamma.size()+beta.size());
@ -1426,7 +1426,7 @@ namespace dlib
auto g = gamma(params, 0);
auto g_grad = gamma(params_grad, 0);
auto b_grad = beta(params_grad, gamma.size());
tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad);
tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad, dmeans, dvars);
}
const tensor& get_layer_params() const { return params; };
@ -1493,6 +1493,7 @@ namespace dlib
resizable_tensor params;
alias_tensor gamma, beta;
resizable_tensor means, invstds;
resizable_tensor dmeans, dvars;
double learning_rate_multiplier;
double weight_decay_multiplier;
double bias_learning_rate_multiplier;

View File

@ -607,7 +607,7 @@ namespace
tt::tensor_rand rnd(0);
rnd.fill_uniform(x);
resizable_tensor means_cpu(x.num_samples()), invstds_cpu(x.num_samples());
resizable_tensor gamma(1, x.k(), x.nr(), x.nc()), beta(1, x.k(), x.nr(), x.nc());
resizable_tensor gamma(1, x.k(), 1, 1), beta(1, x.k(), 1, 1);
gamma = 1;
beta = 0;
const float eps = 1e-5;
@ -639,16 +639,19 @@ namespace
DLIB_TEST(max(abs(mat(means_cpu) - mat(means_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(invstds_cpu) - mat(invstds_cuda))) < 1e-5);
resizable_tensor gradient_input(x);
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), x.nr(), x.nc()), beta_grad_cpu(1, x.k(), x.nr(), x.nc());
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), x.nr(), x.nc()), beta_grad_cuda(1, x.k(), x.nr(), x.nc());
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), 1, 1), beta_grad_cpu(1, x.k(), 1, 1);
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), 1, 1), beta_grad_cuda(1, x.k(), 1, 1);
resizable_tensor dmeans_cpu, dvars_cpu, dmeans_cuda, dvars_cuda;
rnd.fill_gaussian(gradient_input);
src_grad_cpu = 0;
src_grad_cuda = 0;
cpu::layer_normalize_gradient(eps, gradient_input, means_cpu, invstds_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, beta_grad_cpu);
cuda::layer_normalize_gradient(eps, gradient_input, means_cuda, invstds_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, beta_grad_cuda);
cpu::layer_normalize_gradient(eps, gradient_input, means_cpu, invstds_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, beta_grad_cpu, dmeans_cpu, dvars_cpu);
cuda::layer_normalize_gradient(eps, gradient_input, means_cuda, invstds_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, beta_grad_cuda, dmeans_cuda, dvars_cuda);
DLIB_TEST(max(abs(mat(src_grad_cpu) - mat(src_grad_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(gamma_grad_cpu) - mat(gamma_grad_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(beta_grad_cpu) - mat(beta_grad_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(dmeans_cpu) - mat(dmeans_cuda))) < 1e-4);
DLIB_TEST(max(abs(mat(dvars_cpu) - mat(dvars_cuda))) < 1e-4);
#endif
}