Added prototypes for batch normalization's gradients.

Davis King 2015-11-07 19:59:31 -05:00
parent 51ebcfc7f4
commit 6539ea6779
1 changed file with 46 additions and 1 deletion


@@ -52,7 +52,6 @@ namespace dlib
// -----------------------------------------------------------------------------------
// TODO, add versions of batch_normalize() that output the gradients.
void batch_normalize (
resizable_tensor& dest,
resizable_tensor& means,
@@ -63,6 +62,7 @@ namespace dlib
);
/*!
requires
- src.num_samples() > 1
- gamma.num_samples() == 1
- beta.num_samples() == 1
- gamma.nr() == beta.nr() == src.nr()
@@ -80,6 +80,39 @@ namespace dlib
- #vars == the variance values of the contents of src.
!*/
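// Usage sketch (not part of the diff above): how batch_normalize() is meant to be
// called per the spec.  The tensor shapes are arbitrary examples, and set_size()
// plus scalar assignment on tensors are assumed from dlib's tensor interface.
resizable_tensor src, gamma, beta, dest, means, vars;
src.set_size(5, 3, 4, 4);     // num_samples() > 1, as required
gamma.set_size(1, 3, 4, 4);   // one scale value per element of a sample
beta.set_size(1, 3, 4, 4);    // one offset value per element of a sample
gamma = 1;                    // identity scaling
beta = 0;                     // zero offset
batch_normalize(dest, means, vars, src, gamma, beta);
// dest now holds the normalized src; means and vars hold the batch statistics.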
void batch_normalize_gradient (
const tensor& gradient_input,
const tensor& means,
const tensor& vars,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
/*!
requires
- vars and means should be the output of a call to
batch_normalize(dest,means,vars,src,gamma,beta)
- have_same_dimensions(gradient_input, src) == true
- have_same_dimensions(src, src_grad) == true
- src.num_samples() > 1
- gamma.num_samples() == 1
- have_same_dimensions(gamma, gamma_grad) == true
- have_same_dimensions(gamma, beta_grad) == true
- gamma.nr() == src.nr()
- gamma.nc() == src.nc()
- gamma.k() == src.k()
- have_same_dimensions(means, gamma) == true
- have_same_dimensions(vars, gamma) == true
ensures
- Let f(src,gamma,beta) == dot(gradient_input, dest output of
batch_normalize(dest,means,vars,src,gamma,beta))
- Adds the gradient of f() with respect to src to #src_grad
- Adds the gradient of f() with respect to gamma to #gamma_grad
- Adds the gradient of f() with respect to beta to #beta_grad
!*/
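// Sketch of pairing the forward pass with this gradient routine, following the
// requires clauses above.  The upstream gradient value is a stand-in, and
// copy_size() plus scalar assignment on tensors are assumed for illustration.
resizable_tensor gradient_input, src_grad, gamma_grad, beta_grad;
gradient_input.copy_size(src);                // same dimensions as src
gradient_input = 1;                           // pretend gradient flowing back from above
src_grad.copy_size(src);      src_grad = 0;   // gradients are *added* to these outputs,
gamma_grad.copy_size(gamma);  gamma_grad = 0; // so zero them first for a clean result
beta_grad.copy_size(beta);    beta_grad = 0;
batch_normalize_gradient(gradient_input, means, vars, src, gamma,
                         src_grad, gamma_grad, beta_grad);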
void batch_normalize_conv (
resizable_tensor& dest,
resizable_tensor& means,
@@ -90,6 +123,7 @@ namespace dlib
);
/*!
requires
- src.num_samples() > 1
- gamma.num_samples()==gamma.nr()==gamma.nc() == 1
- beta.num_samples() ==beta.nr() ==beta.nc() == 1
- gamma.k() == beta.k() == src.k()
@@ -103,6 +137,17 @@ namespace dlib
- #vars == the variance values of the contents of src.
!*/
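// Sketch of the convolutional variant: gamma and beta carry one value per channel,
// i.e. num_samples()==nr()==nc()==1 and k()==src.k().  The argument order is assumed
// to mirror batch_normalize(), and the shapes are illustrative.
resizable_tensor csrc, cgamma, cbeta, cdest, cmeans, cvars;
csrc.set_size(5, 3, 8, 8);
cgamma.set_size(1, 3, 1, 1);  // one scale per channel
cbeta.set_size(1, 3, 1, 1);   // one offset per channel
cgamma = 1;
cbeta = 0;
batch_normalize_conv(cdest, cmeans, cvars, csrc, cgamma, cbeta);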
void batch_normalize_conv_gradient (
const tensor& gradient_input,
const tensor& means,
const tensor& vars,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
// -----------------------------------------------------------------------------------
class dropout