mirror of https://github.com/davisking/dlib.git
Fix layer_normalize gradients (#3001)
* Fix layer_normalize gradients * fix layer_norm CPU * attempt to fix the cuda version * fix gamma_grad and beta_grad * update cuda test * use a block of size 1 to avoid race conditions * improve the speed of CUDA path of layer_norm * improve the speed of CUDA path of layer_norm
This commit is contained in:
parent
27a0135220
commit
253098eb1b
|
@ -1270,22 +1270,19 @@ namespace dlib
|
|||
const tensor& beta
|
||||
)
|
||||
{
|
||||
const long num = src.k() * src.nr() * src.nc();
|
||||
DLIB_CASSERT(
|
||||
have_same_dimensions(gamma, beta) &&
|
||||
src.k() == gamma.k() &&
|
||||
src.nr() == gamma.nr() &&
|
||||
src.nc() == gamma.nc() &&
|
||||
gamma.k() == src.k() &&
|
||||
gamma.nr() == 1 &&
|
||||
gamma.nc() == 1 &&
|
||||
eps > 0,
|
||||
"\nsrc.k(): " << src.k() <<
|
||||
"\ngamma.k(): " << gamma.k() <<
|
||||
"\ngamma.nr(): " << gamma.nr() <<
|
||||
"\ngamma.nc(): " << gamma.nc() <<
|
||||
"\nbeta.k(): " << beta.k() <<
|
||||
"\nbeta.nr(): " << beta.nr() <<
|
||||
"\nbeta.nc(): " << beta.nc() <<
|
||||
"\nsrc.k(): " << src.k() <<
|
||||
"\nsrc.nr(): " << src.nr() <<
|
||||
"\nsrc.nc(): " << src.nc() <<
|
||||
"\neps: " << eps
|
||||
);
|
||||
|
||||
|
@ -1296,46 +1293,53 @@ namespace dlib
|
|||
// first compute means and invstds
|
||||
means = 0;
|
||||
invstds = 0;
|
||||
const auto p_invstds = invstds.host();
|
||||
const auto p_means = means.host();
|
||||
auto p_src = src.host();
|
||||
const float* p_src = src.host();
|
||||
float* p_invstds = invstds.host();
|
||||
float* p_means = means.host();
|
||||
const long num = src.nr() * src.nc();
|
||||
// compute means, and sum of squares
|
||||
for (long n = 0; n < src.num_samples(); ++n)
|
||||
{
|
||||
for (long k = 0; k < src.k(); ++k)
|
||||
{
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
float val = p_src[n*num+i];
|
||||
p_means[n] += val;
|
||||
p_invstds[n] += val*val;
|
||||
p_means[n] += *p_src;
|
||||
p_invstds[n] += (*p_src) * (*p_src);
|
||||
++p_src;
|
||||
}
|
||||
}
|
||||
means /= num;
|
||||
invstds /= num;
|
||||
}
|
||||
means /= src.k() * num;
|
||||
invstds /= src.k () * num;
|
||||
// copy data back to host
|
||||
invstds.host(); means.host();
|
||||
invstds.host();
|
||||
means.host();
|
||||
|
||||
// compute variances
|
||||
for (long n = 0; n < src.num_samples(); ++n)
|
||||
{
|
||||
auto var = p_invstds[n] - p_means[n] * p_means[n];
|
||||
p_invstds[n] = 1.0f / std::sqrt(var + eps);
|
||||
p_invstds[n] = 1.0f / std::sqrt(p_invstds[n] - p_means[n] * p_means[n] + eps);
|
||||
}
|
||||
|
||||
p_src = src.host();
|
||||
auto p_dest = dest.host();
|
||||
auto p_gamma = gamma.host();
|
||||
auto p_beta = beta.host();
|
||||
float* p_dest = dest.host();
|
||||
const float* p_gamma = gamma.host();
|
||||
const float* p_beta = beta.host();
|
||||
for (long n = 0; n < src.num_samples(); ++n)
|
||||
{
|
||||
for (long k = 0; k < src.k(); ++k)
|
||||
{
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
*p_dest = (*p_src - p_means[n]) * p_invstds[n];
|
||||
*p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
|
||||
*p_dest = (*p_dest) * p_gamma[k] + p_beta[k];
|
||||
++p_src;
|
||||
++p_dest;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void layer_normalize_gradient (
|
||||
const double eps,
|
||||
|
@ -1346,22 +1350,26 @@ namespace dlib
|
|||
const tensor& gamma,
|
||||
tensor& src_grad,
|
||||
tensor& gamma_grad,
|
||||
tensor& beta_grad
|
||||
tensor& beta_grad,
|
||||
resizable_tensor& dmeans,
|
||||
resizable_tensor& dvars
|
||||
)
|
||||
{
|
||||
const long num = src.k() * src.nr() * src.nc();
|
||||
const long num = src.nr() * src.nc();
|
||||
DLIB_CASSERT(src.num_samples() == means.size());
|
||||
DLIB_CASSERT(src.num_samples() == invstds.size());
|
||||
DLIB_CASSERT(src.k() == gamma.k());
|
||||
DLIB_CASSERT(src.nr() == gamma_grad.nr());
|
||||
DLIB_CASSERT(src.nc() == beta_grad.nc());
|
||||
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
|
||||
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
|
||||
DLIB_CASSERT(gamma.k() == src.k());
|
||||
DLIB_CASSERT(gamma.nr() == 1);
|
||||
DLIB_CASSERT(gamma.nc() == 1);
|
||||
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
|
||||
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
|
||||
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
|
||||
DLIB_CASSERT(eps > 0);
|
||||
|
||||
beta_grad = 0;
|
||||
gamma_grad = 0;
|
||||
|
||||
auto p_grad = gradient_input.host();
|
||||
auto p_src = src.host();
|
||||
const auto p_gamma = gamma.host();
|
||||
|
@ -1370,7 +1378,6 @@ namespace dlib
|
|||
const auto p_invstds = invstds.host();
|
||||
const auto p_means = means.host();
|
||||
|
||||
resizable_tensor dvars, dmeans;
|
||||
dvars.copy_size(invstds);
|
||||
dmeans.copy_size(means);
|
||||
dvars = 0;
|
||||
|
@ -1379,57 +1386,66 @@ namespace dlib
|
|||
const auto p_dmeans = dmeans.host();
|
||||
|
||||
for (long n = 0; n < src.num_samples(); ++n)
|
||||
{
|
||||
const float invstd_pow = -0.5 * std::pow(p_invstds[n], 3.0f);
|
||||
for (long k = 0; k < src.k(); ++k)
|
||||
{
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
const float x_hat = (*p_src - p_means[n]) * p_invstds[n];
|
||||
p_beta_grad[i] += *p_grad;
|
||||
p_gamma_grad[i] += (*p_grad)*x_hat;
|
||||
p_beta_grad[k] += *p_grad;
|
||||
p_gamma_grad[k] += (*p_grad) * x_hat;
|
||||
|
||||
const float dx = *p_grad * p_gamma[n];
|
||||
const float dx = *p_grad * p_gamma[k];
|
||||
|
||||
p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*p_invstds[n]*p_invstds[n]*p_invstds[n];
|
||||
p_dvars[n] += dx * (*p_src - p_means[n]) * invstd_pow;
|
||||
|
||||
++p_grad;
|
||||
++p_src;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const float invnum = 1.0f/num;
|
||||
p_grad = gradient_input.host();
|
||||
p_src = src.host();
|
||||
const float invnum = 1.0f / (src.k() * num);
|
||||
for (long n = 0; n < src.num_samples(); ++n)
|
||||
{
|
||||
for (long k = 0; k < src.k(); ++k)
|
||||
{
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
const float dx = *p_grad * p_gamma[i];
|
||||
const float dx = *p_grad * p_gamma[k];
|
||||
|
||||
p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum;
|
||||
p_dmeans[n] += -dx * p_invstds[n] + p_dvars[n] * -2 * (*p_src - p_means[n]) * invnum;
|
||||
|
||||
++p_grad;
|
||||
++p_src;
|
||||
}
|
||||
}
|
||||
}
|
||||
p_grad = gradient_input.host();
|
||||
p_src = src.host();
|
||||
auto p_src_grad = src_grad.host();
|
||||
for (long n = 0; n < src.num_samples(); ++n)
|
||||
{
|
||||
for (long k = 0; k < src.k(); ++k)
|
||||
{
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
const float dx = *p_grad * p_gamma[i];
|
||||
const float dx = *p_grad * p_gamma[k];
|
||||
|
||||
*p_src_grad += dx * p_invstds[n] +
|
||||
p_dvars[n] * 2 * (*p_src - p_means[n]) * invnum +
|
||||
p_dmeans[n] * invnum;
|
||||
|
||||
|
||||
++p_grad;
|
||||
++p_src;
|
||||
++p_src_grad;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------------
|
||||
|
||||
|
|
|
@ -250,7 +250,9 @@ namespace dlib
|
|||
const tensor& gamma,
|
||||
tensor& src_grad,
|
||||
tensor& gamma_grad,
|
||||
tensor& beta_grad
|
||||
tensor& beta_grad,
|
||||
resizable_tensor& dmeans,
|
||||
resizable_tensor& dvars
|
||||
);
|
||||
|
||||
// -----------------------------------------------------------------------------------
|
||||
|
|
|
@ -2085,21 +2085,32 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
__global__ void _cuda_layer_normalize(float* out, const float* s, float* m, float* v, const float* g, const float* b, float eps, size_t ns, size_t num)
|
||||
__global__ void _cuda_layer_normalize(
|
||||
float* out,
|
||||
const float* s,
|
||||
float* m,
|
||||
float* v,
|
||||
const float* g,
|
||||
const float* b,
|
||||
float eps,
|
||||
size_t ns,
|
||||
size_t k,
|
||||
size_t num
|
||||
)
|
||||
{
|
||||
// compute means and sum of squares
|
||||
for (auto n : grid_stride_range_y(0, ns))
|
||||
{
|
||||
auto p = s + n * num;
|
||||
const auto ps = s + n * k * num;
|
||||
float means = 0;
|
||||
float invstds = 0;
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
for (auto i : grid_stride_range(0, k * num))
|
||||
{
|
||||
means += p[i];
|
||||
invstds += p[i] * p[i];
|
||||
means += ps[i];
|
||||
invstds += ps[i] * ps[i];
|
||||
}
|
||||
warp_reduce_atomic_add(m[n], means/num);
|
||||
warp_reduce_atomic_add(v[n], invstds/num);
|
||||
warp_reduce_atomic_add(m[n], means / (k * num));
|
||||
warp_reduce_atomic_add(v[n], invstds / (k * num));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
@ -2108,61 +2119,19 @@ namespace dlib
|
|||
{
|
||||
for (auto i : grid_stride_range(0, 1))
|
||||
{
|
||||
auto var = v[n] - m[n] * m[n];
|
||||
v[n] = 1.0f / std::sqrt(var + eps);
|
||||
v[n] = 1.0f / std::sqrt(v[n] - m[n] * m[n] + eps);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto n : grid_stride_range_y(0, ns))
|
||||
{
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
const auto ps = s + n * k * num;
|
||||
const auto pout = out + n * k * num;
|
||||
for (auto i : grid_stride_range(0, k * num))
|
||||
{
|
||||
const float val = (s[n*num+i]-m[n])*v[n];
|
||||
out[n*num+i] = val*g[i]+b[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void _cuda_layer_normalize_gradient(float* out, float* gg, float* bg, const float* s, const float* gi, const float* m, const float* v, const float* g, float* dm, float* dv, float eps, size_t ns, size_t num)
|
||||
{
|
||||
for (auto n : grid_stride_range_y(0, ns))
|
||||
{
|
||||
float temp_dv = 0;
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
{
|
||||
auto idx = n*num+i;
|
||||
const float x_hat = (s[idx] - m[n])*v[n];
|
||||
bg[i] += gi[idx];
|
||||
gg[i] += gi[idx]*x_hat;
|
||||
|
||||
const float dx = gi[idx] * g[n];
|
||||
temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n];
|
||||
}
|
||||
warp_reduce_atomic_add(dv[n], temp_dv);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto n : grid_stride_range_y(0, ns))
|
||||
{
|
||||
float temp_dm = 0;
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
{
|
||||
auto idx = n*num+i;
|
||||
const float dx = gi[idx]*g[i];
|
||||
temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num;
|
||||
}
|
||||
warp_reduce_atomic_add(dm[n], temp_dm);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto n : grid_stride_range_y(0, ns))
|
||||
{
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
{
|
||||
auto idx = n*num+i;
|
||||
const float dx = gi[idx]*g[i];
|
||||
out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num;
|
||||
pout[i] = (ps[i] - m[n]) * v[n];
|
||||
pout[i] = pout[i] * g[i / num] + b[i / num];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2177,22 +2146,20 @@ namespace dlib
|
|||
const tensor& beta
|
||||
)
|
||||
{
|
||||
const long num = src.k() * src.nr() * src.nc();
|
||||
const long num = src.nr() * src.nc();
|
||||
DLIB_CASSERT(
|
||||
have_same_dimensions(gamma, beta) &&
|
||||
src.k() == gamma.k() &&
|
||||
src.nr() == gamma.nr() &&
|
||||
src.nc() == gamma.nc() &&
|
||||
gamma.k() == src.k() &&
|
||||
gamma.nr() == 1 &&
|
||||
gamma.nc() == 1 &&
|
||||
eps > 0,
|
||||
"\nsrc.k(): " << src.k() <<
|
||||
"\ngamma.k(): " << gamma.k() <<
|
||||
"\ngamma.nr(): " << gamma.nr() <<
|
||||
"\ngamma.nc(): " << gamma.nc() <<
|
||||
"\nbeta.k(): " << beta.k() <<
|
||||
"\nbeta.nr(): " << beta.nr() <<
|
||||
"\nbeta.nc(): " << beta.nc() <<
|
||||
"\nsrc.k(): " << src.k() <<
|
||||
"\nsrc.nr(): " << src.nr() <<
|
||||
"\nsrc.nc(): " << src.nc() <<
|
||||
"\neps: " << eps
|
||||
);
|
||||
|
||||
|
@ -2201,8 +2168,78 @@ namespace dlib
|
|||
invstds.set_size(src.num_samples());
|
||||
means = 0;
|
||||
invstds = 0;
|
||||
launch_kernel(_cuda_layer_normalize, max_jobs(num, src.num_samples()), dest.device(), src.device(),
|
||||
means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), num);
|
||||
launch_kernel(_cuda_layer_normalize, max_jobs(src.k() * num, src.num_samples()), dest.device(), src.device(),
|
||||
means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), src.k(), num);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
__global__ void _cuda_layer_normalize_gradient(
|
||||
float* out,
|
||||
float* gg,
|
||||
float* bg,
|
||||
const float* s,
|
||||
const float* gi,
|
||||
const float* m,
|
||||
const float* v,
|
||||
const float* g,
|
||||
float* dm,
|
||||
float* dv,
|
||||
float eps,
|
||||
size_t ns,
|
||||
size_t ks,
|
||||
size_t num)
|
||||
{
|
||||
for (auto nk : grid_stride_range_y(0, ns * ks))
|
||||
{
|
||||
const auto n = nk / ks;
|
||||
const auto k = nk % ks;
|
||||
const auto ps = s + (n * ks + k) * num;
|
||||
const auto pgi = gi + (n * ks + k) * num;
|
||||
const float invstd_pow = -0.5 * std::pow(v[n], 3.0f);
|
||||
float temp_bg = 0;
|
||||
float temp_gg = 0;
|
||||
float temp_dv = 0;
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
{
|
||||
const float x_hat = (ps[i] - m[n]) * v[n];
|
||||
const float dx = pgi[i] * g[i / num];
|
||||
temp_bg += pgi[i];
|
||||
temp_gg += pgi[i] * x_hat;
|
||||
temp_dv += dx * (ps[i] - m[n]) * invstd_pow;
|
||||
}
|
||||
warp_reduce_atomic_add(bg[k], temp_bg);
|
||||
warp_reduce_atomic_add(gg[k], temp_gg);
|
||||
warp_reduce_atomic_add(dv[n], temp_dv);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float invnum = 1.0f / (ks * num);
|
||||
for (auto n : grid_stride_range_y(0, ns))
|
||||
{
|
||||
const auto ps = s + n * ks * num;
|
||||
const auto pgi = gi + n * ks * num;
|
||||
float temp_dm = 0;
|
||||
for (auto i : grid_stride_range(0, ks * num))
|
||||
{
|
||||
const float dx = pgi[i] * g[i / num];
|
||||
temp_dm += -dx * v[n] + dv[n] * -2 * (ps[i] - m[n]) * invnum;
|
||||
}
|
||||
warp_reduce_atomic_add(dm[n], temp_dm);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto n : grid_stride_range_y(0, ns))
|
||||
{
|
||||
const auto ps = s + n * ks * num;
|
||||
const auto pgi = gi + n * ks * num;
|
||||
const auto pout = out + n * ks * num;
|
||||
for (auto i : grid_stride_range(0, ks * num))
|
||||
{
|
||||
const float dx = pgi[i] * g[i / num];
|
||||
pout[i] += dx * v[n] + dv[n] * 2 * (ps[i] - m[n]) * invnum + dm[n] * invnum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void layer_normalize_gradient (
|
||||
|
@ -2214,32 +2251,33 @@ namespace dlib
|
|||
const tensor& gamma,
|
||||
tensor& src_grad,
|
||||
tensor& gamma_grad,
|
||||
tensor& beta_grad
|
||||
tensor& beta_grad,
|
||||
resizable_tensor& dmeans,
|
||||
resizable_tensor& dvars
|
||||
)
|
||||
{
|
||||
const long num = src.k() * src.nr() * src.nc();
|
||||
const long num = src.nr() * src.nc();
|
||||
DLIB_CASSERT(src.num_samples() == means.size());
|
||||
DLIB_CASSERT(src.num_samples() == invstds.size());
|
||||
DLIB_CASSERT(src.k() == gamma.k());
|
||||
DLIB_CASSERT(src.nr() == gamma.nr());
|
||||
DLIB_CASSERT(src.nc() == gamma.nc());
|
||||
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
|
||||
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
|
||||
DLIB_CASSERT(gamma.k() == src.k());
|
||||
DLIB_CASSERT(gamma.nr() == 1);
|
||||
DLIB_CASSERT(gamma.nc() == 1);
|
||||
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
|
||||
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
|
||||
DLIB_CASSERT(have_same_dimensions(gamma_grad, gamma));
|
||||
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
|
||||
DLIB_CASSERT(eps > 0);
|
||||
|
||||
beta_grad = 0;
|
||||
gamma_grad = 0;
|
||||
resizable_tensor dvars, dmeans;
|
||||
dvars.copy_size(invstds);
|
||||
dmeans.copy_size(means);
|
||||
dvars = 0;
|
||||
dmeans = 0;
|
||||
launch_kernel(_cuda_layer_normalize_gradient, max_jobs(num, src.num_samples()),
|
||||
launch_kernel(_cuda_layer_normalize_gradient, max_jobs(src.k() * num, src.num_samples()),
|
||||
src_grad.device(), gamma_grad.device(), beta_grad.device(), src.device(),
|
||||
gradient_input.device(), means.device(), invstds.device(), gamma.device(),
|
||||
dmeans.device(), dvars.device(), eps, src.num_samples(), num);
|
||||
dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
|
|
@ -357,7 +357,9 @@ namespace dlib
|
|||
const tensor& gamma,
|
||||
tensor& src_grad,
|
||||
tensor& gamma_grad,
|
||||
tensor& beta_grad
|
||||
tensor& beta_grad,
|
||||
resizable_tensor& dmeans,
|
||||
resizable_tensor& dvars
|
||||
);
|
||||
|
||||
// -----------------------------------------------------------------------------------
|
||||
|
|
|
@ -684,13 +684,15 @@ namespace dlib { namespace tt
|
|||
const tensor& gamma,
|
||||
tensor& src_grad,
|
||||
tensor& gamma_grad,
|
||||
tensor& beta_grad
|
||||
tensor& beta_grad,
|
||||
resizable_tensor& dmeans,
|
||||
resizable_tensor& dvars
|
||||
)
|
||||
{
|
||||
#ifdef DLIB_USE_CUDA
|
||||
cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
|
||||
cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad, dmeans, dvars);
|
||||
#else
|
||||
cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
|
||||
cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad, dmeans, dvars);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
@ -814,13 +814,13 @@ namespace dlib { namespace tt
|
|||
/*!
|
||||
requires
|
||||
- eps > 0
|
||||
- src.num_samples() == gamma.size() == beta.size()
|
||||
- src.k() == gamma.size() == beta.size()
|
||||
- gamma.num_samples() == gamma.nr() == gamma.nc() == 1
|
||||
- have_same_dimensions(gamma, beta) == true
|
||||
- beta.num_samples() ==beta.nr() ==gamma.nc() == 1
|
||||
ensures
|
||||
- have_same_dimensions(#dest, src) == true
|
||||
- #means.size() == invstds.size() == src.num_samples()
|
||||
- #dest == the normalized version of src.
|
||||
- #dest == the normalized version of src, sample-wise.
|
||||
- #means == the mean values of the contents of src.
|
||||
- #invstds == 1/(the standard deviation values of the contents of src).
|
||||
!*/
|
||||
|
@ -834,7 +834,9 @@ namespace dlib { namespace tt
|
|||
const tensor& gamma,
|
||||
tensor& src_grad,
|
||||
tensor& gamma_grad,
|
||||
tensor& beta_grad
|
||||
tensor& beta_grad,
|
||||
resizable_tensor& dmeans,
|
||||
resizable_tensor& dvars
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
|
@ -847,8 +849,6 @@ namespace dlib { namespace tt
|
|||
- have_same_dimensions(gamma, beta_grad) == true
|
||||
- means.size() == src.num_samples()
|
||||
- invstds.size() == src.num_samples()
|
||||
- have_same_dimensions(means, gamma) == true
|
||||
- have_same_dimensions(invstds, gamma) == true
|
||||
ensures
|
||||
- Let f(src,gamma,beta) == dot(gradient_input, dest output of
|
||||
layer_normalize(eps,dest,means,invstds,src,gamma,beta))
|
||||
|
|
|
@ -1403,7 +1403,7 @@ namespace dlib
|
|||
template <typename SUBNET>
|
||||
void setup (const SUBNET& sub)
|
||||
{
|
||||
gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc());
|
||||
gamma = alias_tensor(1, sub.get_output().k());
|
||||
beta = gamma;
|
||||
|
||||
params.set_size(gamma.size()+beta.size());
|
||||
|
@ -1426,7 +1426,7 @@ namespace dlib
|
|||
auto g = gamma(params, 0);
|
||||
auto g_grad = gamma(params_grad, 0);
|
||||
auto b_grad = beta(params_grad, gamma.size());
|
||||
tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad);
|
||||
tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad, dmeans, dvars);
|
||||
}
|
||||
|
||||
const tensor& get_layer_params() const { return params; };
|
||||
|
@ -1493,6 +1493,7 @@ namespace dlib
|
|||
resizable_tensor params;
|
||||
alias_tensor gamma, beta;
|
||||
resizable_tensor means, invstds;
|
||||
resizable_tensor dmeans, dvars;
|
||||
double learning_rate_multiplier;
|
||||
double weight_decay_multiplier;
|
||||
double bias_learning_rate_multiplier;
|
||||
|
|
|
@ -607,7 +607,7 @@ namespace
|
|||
tt::tensor_rand rnd(0);
|
||||
rnd.fill_uniform(x);
|
||||
resizable_tensor means_cpu(x.num_samples()), invstds_cpu(x.num_samples());
|
||||
resizable_tensor gamma(1, x.k(), x.nr(), x.nc()), beta(1, x.k(), x.nr(), x.nc());
|
||||
resizable_tensor gamma(1, x.k(), 1, 1), beta(1, x.k(), 1, 1);
|
||||
gamma = 1;
|
||||
beta = 0;
|
||||
const float eps = 1e-5;
|
||||
|
@ -639,16 +639,19 @@ namespace
|
|||
DLIB_TEST(max(abs(mat(means_cpu) - mat(means_cuda))) < 1e-5);
|
||||
DLIB_TEST(max(abs(mat(invstds_cpu) - mat(invstds_cuda))) < 1e-5);
|
||||
resizable_tensor gradient_input(x);
|
||||
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), x.nr(), x.nc()), beta_grad_cpu(1, x.k(), x.nr(), x.nc());
|
||||
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), x.nr(), x.nc()), beta_grad_cuda(1, x.k(), x.nr(), x.nc());
|
||||
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), 1, 1), beta_grad_cpu(1, x.k(), 1, 1);
|
||||
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), 1, 1), beta_grad_cuda(1, x.k(), 1, 1);
|
||||
resizable_tensor dmeans_cpu, dvars_cpu, dmeans_cuda, dvars_cuda;
|
||||
rnd.fill_gaussian(gradient_input);
|
||||
src_grad_cpu = 0;
|
||||
src_grad_cuda = 0;
|
||||
cpu::layer_normalize_gradient(eps, gradient_input, means_cpu, invstds_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, beta_grad_cpu);
|
||||
cuda::layer_normalize_gradient(eps, gradient_input, means_cuda, invstds_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, beta_grad_cuda);
|
||||
cpu::layer_normalize_gradient(eps, gradient_input, means_cpu, invstds_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, beta_grad_cpu, dmeans_cpu, dvars_cpu);
|
||||
cuda::layer_normalize_gradient(eps, gradient_input, means_cuda, invstds_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, beta_grad_cuda, dmeans_cuda, dvars_cuda);
|
||||
DLIB_TEST(max(abs(mat(src_grad_cpu) - mat(src_grad_cuda))) < 1e-5);
|
||||
DLIB_TEST(max(abs(mat(gamma_grad_cpu) - mat(gamma_grad_cuda))) < 1e-5);
|
||||
DLIB_TEST(max(abs(mat(beta_grad_cpu) - mat(beta_grad_cuda))) < 1e-5);
|
||||
DLIB_TEST(max(abs(mat(dmeans_cpu) - mat(dmeans_cuda))) < 1e-4);
|
||||
DLIB_TEST(max(abs(mat(dvars_cpu) - mat(dvars_cuda))) < 1e-4);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue