mirror of https://github.com/davisking/dlib.git
Added prototypes for batch normalization's gradients.
This commit is contained in:
parent
51ebcfc7f4
commit
6539ea6779
|
@ -52,7 +52,6 @@ namespace dlib
|
|||
|
||||
// -----------------------------------------------------------------------------------
|
||||
|
||||
// TODO, add versions of batch_normalize() that output the gradients.
|
||||
void batch_normalize (
|
||||
resizable_tensor& dest,
|
||||
resizable_tensor& means,
|
||||
|
@ -63,6 +62,7 @@ namespace dlib
|
|||
);
|
||||
/*!
|
||||
requires
|
||||
- src.num_samples() > 1
|
||||
- gamma.num_samples() == 1
|
||||
- beta.num_samples() == 1
|
||||
- gamma.nr() == beta.nr() == src.nr()
|
||||
|
@ -80,6 +80,39 @@ namespace dlib
|
|||
- #vars == the variance values of the contents of src.
|
||||
!*/
|
||||
|
||||
void batch_normalize_gradient (
|
||||
const tensor& gradient_input,
|
||||
const tensor& means,
|
||||
const tensor& vars,
|
||||
const tensor& src,
|
||||
const tensor& gamma,
|
||||
tensor& src_grad,
|
||||
tensor& gamma_grad,
|
||||
tensor& beta_grad
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- vars and means should be the output of a call to
|
||||
batch_normalize(dest,means,vars,src,gamma,beta)
|
||||
- have_same_dimensions(gradient_input, src) == true
|
||||
- have_same_dimensions(src, src_grad) == true
|
||||
- src.num_samples() > 1
|
||||
- gamma.num_samples() == 1
|
||||
- have_same_dimensions(gamma, gamma_grad) == true
|
||||
- have_same_dimensions(gamma, beta_grad) == true
|
||||
- gamma.nr() == src.nr()
|
||||
- gamma.nc() == src.nc()
|
||||
- gamma.k() == src.k()
|
||||
- have_same_dimensions(means, gamma) == true
|
||||
- have_same_dimensions(vars, gamma) == true
|
||||
ensures
|
||||
- Let f(src,gamma,beta) == dot(gradient_input, dest output of
|
||||
batch_normalize(dest,means,vars,src,gamma,beta))
|
||||
- Adds the gradient of f() with respect to src to #src_grad
|
||||
- Adds the gradient of f() with respect to gamma to #gamma_grad
|
||||
- Adds the gradient of f() with respect to beta to #beta_grad
|
||||
!*/
|
||||
|
||||
void batch_normalize_conv (
|
||||
resizable_tensor& dest,
|
||||
resizable_tensor& means,
|
||||
|
@ -90,6 +123,7 @@ namespace dlib
|
|||
);
|
||||
/*!
|
||||
requires
|
||||
- src.num_samples() > 1
|
||||
- gamma.num_samples()==gamma.nr()==gamma.nc() == 1
|
||||
- beta.num_samples() ==beta.nr() ==beta.nc() == 1
|
||||
- gamma.k() == beta.k() == src.k()
|
||||
|
@ -103,6 +137,17 @@ namespace dlib
|
|||
- #vars == the variance values of the contents of src.
|
||||
!*/
|
||||
|
||||
void batch_normalize_conv_gradient (
|
||||
const tensor& gradient_input,
|
||||
const tensor& means,
|
||||
const tensor& vars,
|
||||
const tensor& src,
|
||||
const tensor& gamma,
|
||||
tensor& src_grad,
|
||||
tensor& gamma_grad,
|
||||
tensor& beta_grad
|
||||
);
|
||||
|
||||
// -----------------------------------------------------------------------------------
|
||||
|
||||
class dropout
|
||||
|
|
Loading…
Reference in New Issue