mirror of https://github.com/davisking/dlib.git
Made computed_output an optional argument to backward_inplace() so there is
symmetry with the non-inplace version. This also enables additional optimizations in the resulting network.
commit 28475b8d0d
parent 122f2fa6b5
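In practice this means a layer's backward_inplace() can now be written in either of two forms. A sketch of the two accepted signatures (tensor is dlib::tensor):

    // Form 1: receives the output that forward_inplace() computed.
    void backward_inplace(const tensor& computed_output, const tensor& gradient_input,
                          tensor& data_grad, tensor& params_grad);

    // Form 2 (enabled by this commit): omits computed_output.  Layers whose
    // gradient doesn't depend on the forward output can use this form, which
    // lets dlib run them in-place and avoid storing the forward result.
    void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad);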
@@ -113,6 +113,15 @@ namespace dlib
         return true;
     }
 
+    template <typename layer_type, typename SUBNET>
+    constexpr auto backward_requires_forward_output(
+        layer_type& layer,
+        SUBNET& sub
+    ) -> typename alwaysbool<decltype(layer.backward_inplace(rt(),sub.get_gradient_input(),rt()))>::type
+    {
+        return false;
+    }
+
     template <typename layer_type, typename SUBNET>
     constexpr auto has_inplace_backward(
         layer_type& layer,
@@ -140,6 +149,15 @@ namespace dlib
         return true;
     }
 
+    template <typename layer_type, typename SUBNET>
+    constexpr auto has_inplace_backward(
+        layer_type& layer,
+        SUBNET& sub
+    ) -> typename alwaysbool<decltype(layer.backward_inplace(rt(),sub.get_gradient_input(),rt()))>::type
+    {
+        return true;
+    }
+
     template <typename layer_type, typename SUBNET>
     constexpr auto is_inplace_layer(
         layer_type& layer,
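The two hunks above rely on expression SFINAE: the decltype() in each trailing return type makes an overload viable only when the layer actually provides the matching backward_inplace() signature, so the traits report which form a layer uses. A standalone sketch of the idiom (simplified names and signatures, not dlib's exact code):

    #include <iostream>

    struct tensor {};
    tensor& rt();  // declared but never defined; only used inside decltype()

    template <typename T> struct alwaysbool { typedef bool type; };

    // Viable only if the layer has the 4-argument backward_inplace().
    template <typename layer_type>
    constexpr auto backward_requires_forward_output(layer_type& layer)
        -> typename alwaysbool<decltype(layer.backward_inplace(rt(),rt(),rt(),rt()))>::type
    { return true; }

    // Viable only if the layer has the 3-argument form (no computed_output).
    template <typename layer_type>
    constexpr auto backward_requires_forward_output(layer_type& layer)
        -> typename alwaysbool<decltype(layer.backward_inplace(rt(),rt(),rt()))>::type
    { return false; }

    struct needs_output { void backward_inplace(const tensor&, const tensor&, tensor&, tensor&) {} };
    struct no_output    { void backward_inplace(const tensor&, tensor&, tensor&) {} };

    int main()
    {
        needs_output a;
        no_output b;
        std::cout << backward_requires_forward_output(a) << '\n';  // prints 1
        std::cout << backward_requires_forward_output(b) << '\n';  // prints 0
    }

Because a layer defines only one of the two signatures, exactly one overload survives substitution and the choice is made entirely at compile time.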
@@ -194,6 +212,18 @@ namespace dlib
         layer.backward_inplace(computed_output,gradient_input,sub.get_gradient_input(),params_grad);
     }
 
+    template <typename layer_type, typename SUBNET>
+    auto call_layer_backward(
+        layer_type& layer,
+        const tensor& ,
+        const tensor& gradient_input,
+        SUBNET& sub,
+        tensor& params_grad
+    ) -> decltype(layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad))
+    {
+        layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad);
+    }
+
 
     template <typename layer_type, typename SUBNET>
     auto call_layer_forward(
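The same idiom handles the dispatch itself: call_layer_backward() gets one overload per signature, and the new overload for the reduced form simply drops computed_output, so every call site can keep passing it uniformly. A simplified sketch (assumed, flattened signatures without the SUBNET parameter):

    struct tensor {};

    // Forwards to the 4-argument backward_inplace() when the layer has one.
    template <typename layer_type>
    auto call_layer_backward(layer_type& layer, const tensor& computed_output,
                             const tensor& gradient_input, tensor& data_grad, tensor& params_grad)
        -> decltype(layer.backward_inplace(computed_output, gradient_input, data_grad, params_grad))
    {
        layer.backward_inplace(computed_output, gradient_input, data_grad, params_grad);
    }

    // Forwards to the 3-argument form, ignoring computed_output entirely.
    // This is what lets dlib avoid keeping the forward output alive for
    // layers that declare they don't need it.
    template <typename layer_type>
    auto call_layer_backward(layer_type& layer, const tensor& /*computed_output*/,
                             const tensor& gradient_input, tensor& data_grad, tensor& params_grad)
        -> decltype(layer.backward_inplace(gradient_input, data_grad, params_grad))
    {
        layer.backward_inplace(gradient_input, data_grad, params_grad);
    }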
@@ -559,7 +559,6 @@ namespace dlib
         }
 
         void backward_inplace(
-            const tensor& /*computed_output*/,
             const tensor& gradient_input,
             tensor& data_grad,
             tensor& /*params_grad*/
@@ -99,7 +99,7 @@ namespace dlib
        to document the interface that a layer object must implement.
 
        The central work of defining a layer is implementing the forward and backward
-       methods. When you do this you have three options:
+       methods. When you do this you have four options:
            - Implement the forward() and backward() methods according to the
              specification shown below. Do not implement forward_inplace() and
              backward_inplace().
@@ -113,6 +113,12 @@ namespace dlib
              according to the specification shown below. Do not implement
              forward() and backward(). These in-place methods allow some types of
              layers to be implemented more efficiently.
+            - Implement the forward_inplace() and backward_inplace() methods
+              according to the specification shown below, except exclude the
+              computed_output parameter from backward_inplace(). Doing this will
+              allow dlib to make some layers execute in-place and therefore run a
+              little faster and use less memory. Do not implement forward() and
+              backward().
    !*/
 
    public:
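To make the new fourth option concrete, here is a hypothetical layer (scale_by_two_, made up for illustration and not part of dlib or this commit) written against it. Its gradient doesn't depend on the forward output, so backward_inplace() can omit computed_output (serialization and printing support are elided for brevity):

    #include <dlib/dnn.h>
    using namespace dlib;

    // Hypothetical example layer: computes output = 2*input.
    class scale_by_two_
    {
    public:
        template <typename SUBNET> void setup (const SUBNET& /*sub*/) {}

        void forward_inplace(const tensor& input, tensor& output)
        {
            // output = 2*input + 0; affine_transform works even when
            // &output == &input, which is what makes in-place execution safe.
            tt::affine_transform(output, input, 2.0f, 0.0f);
        }

        // The reduced signature: no computed_output parameter.  Since
        // d(2x)/dx = 2, the data gradient is just 2*gradient_input.
        void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& /*params_grad*/)
        {
            tt::affine_transform(data_grad, gradient_input, 2.0f, 0.0f);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

    private:
        resizable_tensor params;  // empty: this layer has no learnable parameters
    };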
@@ -239,7 +245,7 @@ namespace dlib
        !*/
 
        void backward_inplace(
-            const tensor& computed_output,
+            const tensor& computed_output, // this parameter is optional
            const tensor& gradient_input,
            tensor& data_grad,
            tensor& params_grad
@@ -503,7 +509,7 @@ namespace dlib
 
        template <typename SUBNET> void setup (const SUBNET& sub);
        void forward_inplace(const tensor& input, tensor& output);
-       void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+       void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
        const tensor& get_layer_params() const;
        tensor& get_layer_params();
        /*!