mirror of https://github.com/davisking/dlib.git
Add fuse layers for conv+affine+relu and conv+relu (#2842)
* Add fuse layers for conv+affine+relu and conv+relu
* Add relu to tensor_conv for cpu
* Update convolution serialization
* Move disable_duplicative_biases documentation from layers_abstract to visitors_abstract
* Fix convolution copy
* Update dlib/dnn/layers_abstract.h

---------

Co-authored-by: Facundo Galan <fgalan@danaide.com.ar>
Co-authored-by: Davis E. King <davis685@gmail.com>
This commit is contained in:
parent efae642813
commit be2fa7f93c
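For orientation, here is a minimal sketch of how the fused path added by this commit is typically reached from user code. The network layout, tensor sizes, and dummy input below are illustrative and not part of the commit; fuse_layers() is the existing dlib visitor that this change extends.

    #include <dlib/dnn.h>
    using namespace dlib;

    // conv -> affine -> relu: fuse_layers() folds the affine into the convolution
    // weights and asks the convolution to apply relu itself (the new use_relu flag).
    using net_type = loss_multiclass_log<
        fc<10, relu<affine<con<16, 3, 3, 1, 1, input<matrix<unsigned char>>>>>>>;

    int main()
    {
        net_type net;
        matrix<unsigned char> img(28, 28);
        img = 0;
        net(img);          // run once so every layer allocates its parameters
        fuse_layers(net);  // con_ now has use_relu == true; affine_ and relu_ are disabled
    }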
@@ -2647,12 +2647,14 @@ namespace dlib
         resizable_tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
+        const tensor& biases,
+        bool use_relu
     )
     {
         DLIB_CASSERT(filters.num_samples() == biases.k());
         (*this)(add_to_output, output,data,filters);
         tt::add(1, output, 1, biases);
+        if (use_relu) tt::relu(output, output);
     }

     void tensor_conv::operator() (
@@ -2660,12 +2662,14 @@ namespace dlib
         tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
+        const tensor& biases,
+        bool use_relu
     )
     {
         DLIB_CASSERT(filters.num_samples() == biases.k());
         (*this)(add_to_output, output, data, filters);
         tt::add(1, output, 1, biases);
+        if (use_relu) tt::relu(output, output);
     }
@@ -603,7 +603,8 @@ namespace dlib
         resizable_tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
+        const tensor& biases,
+        bool use_relu
     );

     void operator() (
@@ -611,7 +612,8 @@ namespace dlib
         tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
+        const tensor& biases,
+        bool use_relu
     );

     void get_gradient_for_data (
@@ -1160,13 +1160,14 @@ namespace dlib
         resizable_tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
+        const tensor& biases,
+        bool use_relu
     )
     {
         DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function");

         output.set_size(out_num_samples, out_k, out_nr, out_nc);
-        (*this)(add_to_output, static_cast<tensor&>(output), data, filters, biases);
+        (*this)(add_to_output, static_cast<tensor&>(output), data, filters, biases, use_relu);
     }

     void tensor_conv::operator() (
@@ -1174,14 +1175,16 @@ namespace dlib
         tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
+        const tensor& biases,
+        bool use_relu
     )
     {
         // Function cudnnConvolutionBiasActivationForward should only be called with CUDNN_ACTIVATION_IDENTITY when
         // the chosen forward algorithm is CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, as cuDNN documentation explicitly says.
         // In case the algorithm is different, perform the forward pass and bias addition separately.
-        if (forward_algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
+        // If use_relu is true, any algorithm can be used.
+        if (!use_relu && forward_algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
         {
             (*this)(add_to_output, output, data, filters);
@@ -1245,7 +1248,7 @@ namespace dlib
             out,
             descriptor(biases),
             biases.device(),
-            identity_activation_descriptor(),
+            use_relu ? relu_activation_descriptor() : identity_activation_descriptor(),
             out_desc,
             out));
     }
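The relu_activation_descriptor() helper used in the last hunk is not part of this diff. As a rough, hypothetical sketch of what such a helper can look like (dlib's actual implementation may differ; only the two cuDNN calls below are standard API):

    #include <cudnn.h>

    // Hypothetical helper: build a CUDNN_ACTIVATION_RELU descriptor once and reuse it.
    static cudnnActivationDescriptor_t relu_activation_descriptor()
    {
        static cudnnActivationDescriptor_t desc = []
        {
            cudnnActivationDescriptor_t d;
            cudnnCreateActivationDescriptor(&d);
            // arguments: mode, NaN propagation policy, coef (unused for relu)
            cudnnSetActivationDescriptor(d, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0);
            return d;
        }();
        return desc;
    }

With use_relu == false the call keeps the identity activation, which cuDNN only allows together with CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; with use_relu == true any forward algorithm is acceptable, which is why the algorithm check above is relaxed.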
@@ -190,7 +190,8 @@ namespace dlib
         resizable_tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
+        const tensor& biases,
+        bool use_relu
     );

     void operator() (
@@ -198,7 +199,8 @@ namespace dlib
         tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
+        const tensor& biases,
+        bool use_relu
     );

     void get_gradient_for_data (
@@ -1047,8 +1047,9 @@ namespace dlib { namespace tt
         tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
-    ) { impl(add_to_output,output,data,filters,biases); }
+        const tensor& biases,
+        bool use_relu
+    ) { impl(add_to_output,output,data,filters,biases,use_relu); }
     /*!
         requires
             - setup() has been called. Specifically, setup() has been called like this:
@@ -1069,6 +1070,8 @@ namespace dlib { namespace tt
               previous values in output.
             - Adds biases to the result of the convolved data
            - filters contains filters.num_samples() filters.
+            - If use_relu==true, then a relu activation will be applied to the result
+              of convolution+bias.
     !*/

     void operator() (
@@ -1076,8 +1079,9 @@ namespace dlib { namespace tt
         resizable_tensor& output,
         const tensor& data,
         const tensor& filters,
-        const tensor& biases
-    ) { impl(add_to_output,output,data,filters, biases); }
+        const tensor& biases,
+        bool use_relu
+    ) { impl(add_to_output,output,data,filters,biases,use_relu); }
     /*!
         requires
             - setup() has been called. Specifically, setup() has been called like this:
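A minimal sketch of calling this low-level interface directly with the new flag. The tensor shapes are illustrative; most code goes through the con_ layer instead of tt::tensor_conv.

    #include <dlib/dnn.h>
    using namespace dlib;

    int main()
    {
        resizable_tensor data(1, 3, 32, 32);   // N x K x NR x NC input
        resizable_tensor filters(8, 3, 3, 3);  // 8 filters over 3 channels, 3x3 each
        resizable_tensor biases(1, 8);         // one bias per filter (biases.k() == filters.num_samples())
        resizable_tensor output;
        data = 1; filters = 0.1f; biases = -0.5f;

        tt::tensor_conv conv;
        conv.setup(data, filters, 1, 1, 1, 1);  // stride_y, stride_x, padding_y, padding_x
        // false => overwrite output; the trailing true applies relu to convolution+bias
        conv(false, output, data, filters, biases, true);
    }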
@@ -60,7 +60,8 @@ namespace dlib
         num_filters_(o.num_outputs),
         padding_y_(_padding_y),
         padding_x_(_padding_x),
-        use_bias(true)
+        use_bias(true),
+        use_relu(false)
     {
         DLIB_CASSERT(num_filters_ > 0);
     }
@@ -107,7 +108,21 @@ namespace dlib
     double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
     void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
     void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }

+    bool relu_is_disabled() const { return !use_relu; }
+
+    void disable_relu()
+    {
+        use_relu = false;
+    }
+
+    void enable_relu()
+    {
+        use_relu = true;
+    }
+
     bool bias_is_disabled() const { return !use_bias; }

     void disable_bias()
     {
         if (use_bias == false)
@@ -123,6 +138,7 @@ namespace dlib
         std::copy(temp.begin(), temp.end() - num_filters_, params.begin());
         biases = alias_tensor();
     }

     void enable_bias()
     {
         if (use_bias == true)
@@ -171,7 +187,8 @@ namespace dlib
         num_filters_(item.num_filters_),
         padding_y_(item.padding_y_),
         padding_x_(item.padding_x_),
-        use_bias(item.use_bias)
+        use_bias(item.use_bias),
+        use_relu(item.use_relu)
     {
         // this->conv is non-copyable and basically stateless, so we have to write our
         // own copy to avoid trying to copy it and getting an error.
@@ -197,6 +214,7 @@ namespace dlib
         bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
         num_filters_ = item.num_filters_;
         use_bias = item.use_bias;
+        use_relu = item.use_relu;
         return *this;
     }
@@ -238,7 +256,8 @@ namespace dlib
             conv(false, output,
                 sub.get_output(),
                 filters(params,0),
-                biases(params, filters.size()));
+                biases(params, filters.size()),
+                use_relu);
         }
         else
         {
@@ -270,7 +289,7 @@ namespace dlib

     friend void serialize(const con_& item, std::ostream& out)
     {
-        serialize("con_5", out);
+        serialize("con_6", out);
         serialize(item.params, out);
         serialize(item.num_filters_, out);
         serialize(_nr, out);
@@ -286,6 +305,7 @@ namespace dlib
         serialize(item.bias_learning_rate_multiplier, out);
         serialize(item.bias_weight_decay_multiplier, out);
         serialize(item.use_bias, out);
+        serialize(item.use_relu, out);
     }

     friend void deserialize(con_& item, std::istream& in)
@@ -296,7 +316,7 @@ namespace dlib
         long nc;
         int stride_y;
         int stride_x;
-        if (version == "con_4" || version == "con_5")
+        if (version == "con_4" || version == "con_5" || version == "con_6")
         {
             deserialize(item.params, in);
             deserialize(item.num_filters_, in);
@@ -318,10 +338,14 @@ namespace dlib
             if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
             if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
             if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
-            if (version == "con_5")
+            if (version == "con_5" || version == "con_6")
             {
                 deserialize(item.use_bias, in);
             }
+            if (version == "con_6")
+            {
+                deserialize(item.use_relu, in);
+            }
         }
         else
         {
@@ -352,6 +376,10 @@ namespace dlib
         {
             out << " use_bias=false";
         }
+        if (item.use_relu)
+        {
+            out << " use_relu="<< std::boolalpha << item.use_relu;
+        }
         return out;
     }
@@ -369,7 +397,9 @@ namespace dlib
             << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
             << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
             << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
-            << " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
+            << " use_bias='"<<(item.use_bias?"true":"false")<<"'"
+            << " use_relu='"<<(item.use_relu?"true":"false")<<"'"
+            << ">\n";
         out << mat(item.params);
         out << "</con>\n";
     }
@@ -391,6 +421,7 @@ namespace dlib
         int padding_y_;
         int padding_x_;
         bool use_bias;
+        bool use_relu;
     };

     template <
@@ -2962,10 +2993,18 @@ namespace dlib
     class relu_
     {
     public:
         relu_()
         {
         }

+        void disable()
+        {
+            params.clear();
+            disabled = true;
+        }
+
+        bool is_disabled() const { return disabled; }
+
         template <typename SUBNET>
         void setup (const SUBNET& /*sub*/)
         {
@@ -2973,6 +3012,9 @@ namespace dlib

         void forward_inplace(const tensor& input, tensor& output)
         {
+            if (disabled)
+                return;
+
             tt::relu(output, input);
         }
@@ -2983,6 +3025,9 @@ namespace dlib
             tensor&
         )
         {
+            if (disabled)
+                return;
+
             tt::relu_gradient(data_grad, computed_output, gradient_input);
         }
@@ -2992,32 +3037,48 @@ namespace dlib
         const tensor& get_layer_params() const { return params; }
         tensor& get_layer_params() { return params; }

-        friend void serialize(const relu_& /*item*/, std::ostream& out)
+        friend void serialize(const relu_& item, std::ostream& out)
         {
-            serialize("relu_", out);
+            serialize("relu_2", out);
+            serialize(item.disabled, out);
         }

-        friend void deserialize(relu_& /*item*/, std::istream& in)
+        friend void deserialize(relu_& item, std::istream& in)
         {
             std::string version;
             deserialize(version, in);
-            if (version != "relu_")
+            if (version == "relu_2")
+            {
+                deserialize(item.disabled, in);
+                return;
+            }
+            if (version != "relu_" && version != "relu_2")
                 throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
         }

-        friend std::ostream& operator<<(std::ostream& out, const relu_& /*item*/)
+        friend std::ostream& operator<<(std::ostream& out, const relu_& item)
         {
             out << "relu";
+            if (item.disabled)
+            {
+                out << "\t (disabled)";
+            }
             return out;
         }

-        friend void to_xml(const relu_& /*item*/, std::ostream& out)
+        friend void to_xml(const relu_& item, std::ostream& out)
         {
-            out << "<relu/>\n";
+            out << "<relu";
+            if (item.disabled)
+            {
+                out << " disabled='"<< std::boolalpha << item.disabled << "'";
+            }
+            out << "/>\n";
         }

     private:
         resizable_tensor params;
+        bool disabled = false;
     };
@@ -939,6 +939,28 @@ namespace dlib
                 - #get_bias_weight_decay_multiplier() == val
         !*/

+        void disable_relu(
+        );
+        /*!
+            ensures
+                - relu_is_disabled() returns true
+        !*/
+
+        void enable_relu(
+        );
+        /*!
+            ensures
+                - relu_is_disabled() returns false
+        !*/
+
+        bool relu_is_disabled(
+        ) const;
+        /*!
+            ensures
+                - returns true if relu is disabled for this layer. This means no activation function
+                  will be applied after the convolution when calling forward.
+        !*/
+
         void disable_bias(
         );
         /*!
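A quick illustration of the new con_ interface documented above (the template arguments are illustrative: num_filters, nr, nc, stride_y, stride_x):

    #include <dlib/dnn.h>

    int main()
    {
        dlib::con_<32, 3, 3, 1, 1> c;
        DLIB_CASSERT(c.relu_is_disabled());   // relu is off by default
        c.enable_relu();                      // forward() now computes relu(conv + bias)
        DLIB_CASSERT(!c.relu_is_disabled());
        c.disable_relu();
    }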
@@ -1821,22 +1843,6 @@ namespace dlib
     template <typename SUBNET>
     using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;

 // ----------------------------------------------------------------------------------------

-    template <typename net_type>
-    void disable_duplicative_biases (
-        const net_type& net
-    );
-    /*!
-        requires
-            - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
-              add_tag_layer.
-        ensures
-            - Disables bias for all bn_ and layer_norm_ inputs.
-            - Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier()
-              to zero of all bn_ and layer_norm_ inputs.
-    !*/
-
-// ----------------------------------------------------------------------------------------
-
     class affine_
@@ -2274,6 +2280,15 @@ namespace dlib
         relu_(
         );

+        void disable(
+        );
+        /*!
+            ensures
+                - #get_layer_params().size() == 0.
+                - when forward_inplace and backward_inplace are called, they return immediately doing nothing,
+                  causing this layer to trivially perform an identity transform.
+        !*/
+
         template <typename SUBNET> void setup (const SUBNET& sub);
         void forward_inplace(const tensor& input, tensor& output);
         void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
@@ -3,6 +3,7 @@
 #ifndef DLIB_DNn_VISITORS_H_
 #define DLIB_DNn_VISITORS_H_

+#include "visitors_abstract.h"
 #include "input.h"
 #include "layers.h"
 #include "loss.h"
@@ -401,7 +402,42 @@ namespace dlib
             // disable other layer types
         }

-        // handle the standard case (convolutional layer followed by affine;
+        // handle the case of convolutional layer followed by relu
+        template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename R>
+        void fuse_convolution(add_layer<relu_, add_layer<con_<nf, nr, nc, sy, sx, py, px>, U>, R>& l)
+        {
+            if (l.layer_details().is_disabled())
+                return;
+
+            // get the convolution below the relu layer
+            auto& conv = l.subnet().layer_details();
+
+            conv.enable_relu();
+
+            // disable the relu layer
+            l.layer_details().disable();
+        }
+
+        // handle the case of convolutional layer followed by affine followed by relu
+        template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E, typename R>
+        void fuse_convolution(add_layer<relu_, add_layer<affine_, add_layer<con_<nf, nr, nc, sy, sx, py, px>, U>, E>, R>& l)
+        {
+            if (l.layer_details().is_disabled())
+                return;
+
+            // fuse the convolutional layer followed by affine
+            fuse_convolution(l.subnet());
+
+            // get the convolution below the affine layer
+            auto& conv = l.subnet().subnet().layer_details();
+
+            conv.enable_relu();
+
+            // disable the relu layer
+            l.layer_details().disable();
+        }
+
+        // handle the case of convolutional layer followed by affine
         template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E>
         void fuse_convolution(add_layer<affine_, add_layer<con_<nf, nr, nc, sy, sx, py, px>, U>, E>& l)
         {
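For reference, the conv+affine fold that the last overload delegates to (the affine-only fuse_convolution, which predates this commit) amounts to a per-output-channel rescaling. Roughly, for output channel k with affine scale gamma_k and shift beta_k:

    fused_filters_k = gamma_k * filters_k
    fused_bias_k    = gamma_k * bias_k + beta_k

After that fold the affine_ layer can be disabled, and with this commit the trailing relu is absorbed as well by setting use_relu on the convolution.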
@@ -27,6 +27,22 @@ namespace dlib
               new_window_size.
     !*/

+// ----------------------------------------------------------------------------------------
+
+    template <typename net_type>
+    void disable_duplicative_biases (
+        const net_type& net
+    );
+    /*!
+        requires
+            - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+              add_tag_layer.
+        ensures
+            - Disables bias for all bn_ and layer_norm_ inputs.
+            - Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier()
+              to zero of all bn_ and layer_norm_ inputs.
+    !*/
+
 // ----------------------------------------------------------------------------------------

     template <typename net_type>
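A minimal usage sketch for disable_duplicative_biases (network layout illustrative): call it once on a freshly built network, before training, so convolutions feeding bn_ layers drop their redundant bias terms.

    #include <dlib/dnn.h>
    using namespace dlib;

    // conv -> batch norm -> relu: the conv bias is redundant because bn_ has its own offset.
    using net_type = loss_multiclass_log<
        fc<10, relu<bn_con<con<16, 3, 3, 1, 1, input<matrix<unsigned char>>>>>>>;

    int main()
    {
        net_type net;
        disable_duplicative_biases(net);  // con_ layers feeding bn_ now run without bias
        // ... train as usual ...
    }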
@@ -42,6 +58,11 @@ namespace dlib
         - Disables all the affine_ layers that have a convolution as an input.
         - Updates the convolution weights beneath the affine_ layers to produce the same
           output as with the affine_ layers enabled.
+        - Disables all the relu_ layers that have a convolution as input.
+        - Disables all the relu_ layers that have an affine_ layer as input, with a
+          convolution as input.
+        - Updates the convolution to apply a relu activation function, to produce the same
+          output as with the relu_ layer enabled.
     !*/

 // ----------------------------------------------------------------------------------------
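Finally, a hedged end-to-end sketch of what the ensures clause means in practice: fusing should not change the network's numerical output (up to float rounding). The network layout and input are illustrative.

    #include <dlib/dnn.h>
    #include <iostream>
    using namespace dlib;

    using net_type = loss_multiclass_log<
        fc<10, relu<affine<con<16, 3, 3, 1, 1, input<matrix<unsigned char>>>>>>>;

    int main()
    {
        net_type net;
        matrix<unsigned char> img(28, 28);
        img = 127;

        net(img);  // initialize and run the unfused network
        const matrix<float> before = mat(net.subnet().get_output());

        fuse_layers(net);
        net(img);  // run again with conv+affine+relu fused into the convolution
        const matrix<float> after = mat(net.subnet().get_output());

        std::cout << "max abs difference: " << max(abs(before - after)) << std::endl;
    }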