Add fuse layers for conv+affine+relu and conv+relu (#2842)

* Add fuse layers for conv+affine+relu and conv+relu

* Add relu to tensor_conv for cpu

* Update convolution serialization

* Move disable_duplicative_biases documentation from layers_abstract to visitors_abstract

* Fix convolution copy

* Update dlib/dnn/layers_abstract.h

---------

Co-authored-by: Facundo Galan <fgalan@danaide.com.ar>
Co-authored-by: Davis E. King <davis685@gmail.com>
Facundo Galán 2023-08-05 13:38:29 -03:00 committed by GitHub
parent efae642813
commit be2fa7f93c
9 changed files with 195 additions and 47 deletions
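
For orientation, here is a minimal sketch of how the new fusion is exercised end to end. It assumes dlib's fuse_layers visitor (the function these hunks extend) and an illustrative network; nothing in the snippet is taken verbatim from the commit.

#include <dlib/dnn.h>

using namespace dlib;

// conv -> affine -> relu: the pattern this commit teaches the fuser to collapse.
using net_type = loss_multiclass_log<
                     fc<10,
                     relu<affine<con<16, 3, 3, 1, 1,
                     input<matrix<rgb_pixel>>>>>>>;

int main()
{
    net_type net;
    // ... deserialize or train the network as usual ...

    // Folds each affine_ into the convolution below it and, with this commit,
    // also folds the trailing relu_ into the convolution (con_::enable_relu()),
    // so inference runs conv + bias + relu as a single fused operation.
    fuse_layers(net);
}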

View File

@@ -2647,12 +2647,14 @@ namespace dlib
resizable_tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
const tensor& biases,
bool use_relu
)
{
DLIB_CASSERT(filters.num_samples() == biases.k());
(*this)(add_to_output, output,data,filters);
tt::add(1, output, 1, biases);
if (use_relu) tt::relu(output, output);
}
void tensor_conv::operator() (
@@ -2660,12 +2662,14 @@ namespace dlib
tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
const tensor& biases,
bool use_relu
)
{
DLIB_CASSERT(filters.num_samples() == biases.k());
(*this)(add_to_output, output, data, filters);
tt::add(1, output, 1, biases);
if (use_relu) tt::relu(output, output);
}

View File

@@ -603,7 +603,8 @@ namespace dlib
resizable_tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
const tensor& biases,
bool use_relu
);
void operator() (
@@ -611,7 +612,8 @@ namespace dlib
tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
const tensor& biases,
bool use_relu
);
void get_gradient_for_data (

View File

@@ -1160,13 +1160,14 @@ namespace dlib
resizable_tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
const tensor& biases,
bool use_relu
)
{
DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function");
output.set_size(out_num_samples, out_k, out_nr, out_nc);
(*this)(add_to_output, static_cast<tensor&>(output), data, filters, biases);
(*this)(add_to_output, static_cast<tensor&>(output), data, filters, biases, use_relu);
}
void tensor_conv::operator() (
@@ -1174,14 +1175,16 @@ namespace dlib
tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
const tensor& biases,
bool use_relu
)
{
// Function cudnnConvolutionBiasActivationForward should only be called with CUDNN_ACTIVATION_IDENTITY when
// the chosen forward algorithm is CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, as cuDNN documentation explicitly says.
// In case the algorithm is different, perform the forward pass and bias addition separately.
if (forward_algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
// If use_relu is true, any algorithm can be used.
if (!use_relu && forward_algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
{
(*this)(add_to_output, output, data, filters);
@@ -1245,7 +1248,7 @@ namespace dlib
out,
descriptor(biases),
biases.device(),
identity_activation_descriptor(),
use_relu ? relu_activation_descriptor() : identity_activation_descriptor(),
out_desc,
out));
}
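
For reference, a rough sketch of driving the fused path through the public tt::tensor_conv wrapper. The tensor shapes and setup() arguments are illustrative assumptions, not taken from this diff; on the CUDA path, passing use_relu == true lets any forward algorithm use the single cudnnConvolutionBiasActivationForward call, as noted above.

#include <dlib/dnn.h>

using namespace dlib;

void conv_bias_relu_sketch()
{
    resizable_tensor data(1, 3, 32, 32);     // samples x channels x rows x cols
    resizable_tensor filters(16, 3, 3, 3);   // 16 filters over 3 input channels
    resizable_tensor biases(1, 16, 1, 1);    // biases.k() == filters.num_samples()
    resizable_tensor output;

    tt::tensor_conv conv;
    conv.setup(data, filters, 1, 1, 0, 0);   // stride_y, stride_x, padding_y, padding_x

    // add_to_output == false, use_relu == true: convolution, bias addition, and
    // relu are performed by one call.
    conv(false, output, data, filters, biases, true);
}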

View File

@@ -190,7 +190,8 @@ namespace dlib
resizable_tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
const tensor& biases,
bool use_relu
);
void operator() (
@@ -198,7 +199,8 @@ namespace dlib
tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
const tensor& biases,
bool use_relu
);
void get_gradient_for_data (

View File

@@ -1047,8 +1047,9 @@ namespace dlib { namespace tt
tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
) { impl(add_to_output,output,data,filters,biases); }
const tensor& biases,
bool use_relu
) { impl(add_to_output,output,data,filters,biases,use_relu); }
/*!
requires
- setup() has been called. Specifically, setup() has been called like this:
@@ -1069,6 +1070,8 @@ namespace dlib { namespace tt
previous values in output.
- Adds biases to the result of the convolved data
- filters contains filters.num_samples() filters.
- If use_relu==true, then a relu activation will be applied to the result
of convolution+bias.
!*/
void operator() (
@@ -1076,8 +1079,9 @@ namespace dlib { namespace tt
resizable_tensor& output,
const tensor& data,
const tensor& filters,
const tensor& biases
) { impl(add_to_output,output,data,filters, biases); }
const tensor& biases,
bool use_relu
) { impl(add_to_output,output,data,filters,biases,use_relu); }
/*!
requires
- setup() has been called. Specifically, setup() has been called like this:

View File

@@ -60,7 +60,8 @@ namespace dlib
num_filters_(o.num_outputs),
padding_y_(_padding_y),
padding_x_(_padding_x),
use_bias(true)
use_bias(true),
use_relu(false)
{
DLIB_CASSERT(num_filters_ > 0);
}
@@ -107,7 +108,21 @@ namespace dlib
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
bool relu_is_disabled() const { return !use_relu; }
void disable_relu()
{
use_relu = false;
}
void enable_relu()
{
use_relu = true;
}
bool bias_is_disabled() const { return !use_bias; }
void disable_bias()
{
if (use_bias == false)
@@ -123,6 +138,7 @@ namespace dlib
std::copy(temp.begin(), temp.end() - num_filters_, params.begin());
biases = alias_tensor();
}
void enable_bias()
{
if (use_bias == true)
@@ -171,7 +187,8 @@ namespace dlib
num_filters_(item.num_filters_),
padding_y_(item.padding_y_),
padding_x_(item.padding_x_),
use_bias(item.use_bias)
use_bias(item.use_bias),
use_relu(item.use_relu)
{
// this->conv is non-copyable and basically stateless, so we have to write our
// own copy to avoid trying to copy it and getting an error.
@@ -197,6 +214,7 @@ namespace dlib
bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
num_filters_ = item.num_filters_;
use_bias = item.use_bias;
use_relu = item.use_relu;
return *this;
}
@@ -238,7 +256,8 @@ namespace dlib
conv(false, output,
sub.get_output(),
filters(params,0),
biases(params, filters.size()));
biases(params, filters.size()),
use_relu);
}
else
{
@@ -270,7 +289,7 @@ namespace dlib
friend void serialize(const con_& item, std::ostream& out)
{
serialize("con_5", out);
serialize("con_6", out);
serialize(item.params, out);
serialize(item.num_filters_, out);
serialize(_nr, out);
@@ -286,6 +305,7 @@ namespace dlib
serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out);
serialize(item.use_bias, out);
serialize(item.use_relu, out);
}
friend void deserialize(con_& item, std::istream& in)
@@ -296,7 +316,7 @@ namespace dlib
long nc;
int stride_y;
int stride_x;
if (version == "con_4" || version == "con_5")
if (version == "con_4" || version == "con_5" || version == "con_6")
{
deserialize(item.params, in);
deserialize(item.num_filters_, in);
@@ -318,10 +338,14 @@ namespace dlib
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
if (version == "con_5")
if (version == "con_5" || version == "con_6")
{
deserialize(item.use_bias, in);
}
if (version == "con_6")
{
deserialize(item.use_relu, in);
}
}
else
{
@@ -352,6 +376,10 @@ namespace dlib
{
out << " use_bias=false";
}
if (item.use_relu)
{
out << " use_relu="<< std::boolalpha << item.use_relu;
}
return out;
}
@@ -369,7 +397,9 @@ namespace dlib
<< " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
<< " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
<< " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
<< " use_bias='"<<(item.use_bias?"true":"false")<<"'"
<< " use_relu='"<<(item.use_relu?"true":"false")<<"'"
<< ">\n";
out << mat(item.params);
out << "</con>\n";
}
@@ -391,6 +421,7 @@ namespace dlib
int padding_y_;
int padding_x_;
bool use_bias;
bool use_relu;
};
template <
@@ -2962,10 +2993,18 @@ namespace dlib
class relu_
{
public:
relu_()
relu_()
{
}
void disable()
{
params.clear();
disabled = true;
}
bool is_disabled() const { return disabled; }
template <typename SUBNET>
void setup (const SUBNET& /*sub*/)
{
@@ -2973,6 +3012,9 @@ namespace dlib
void forward_inplace(const tensor& input, tensor& output)
{
if (disabled)
return;
tt::relu(output, input);
}
@@ -2983,6 +3025,9 @@ namespace dlib
tensor&
)
{
if (disabled)
return;
tt::relu_gradient(data_grad, computed_output, gradient_input);
}
@@ -2992,32 +3037,48 @@ namespace dlib
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const relu_& /*item*/, std::ostream& out)
friend void serialize(const relu_& item, std::ostream& out)
{
serialize("relu_", out);
serialize("relu_2", out);
serialize(item.disabled, out);
}
friend void deserialize(relu_& /*item*/, std::istream& in)
friend void deserialize(relu_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "relu_")
if (version == "relu_2")
{
deserialize(item.disabled, in);
return;
}
if (version != "relu_" && version != "relu_2")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
}
friend std::ostream& operator<<(std::ostream& out, const relu_& /*item*/)
friend std::ostream& operator<<(std::ostream& out, const relu_& item)
{
out << "relu";
if (item.disabled)
{
out << "\t (disabled)";
}
return out;
}
friend void to_xml(const relu_& /*item*/, std::ostream& out)
friend void to_xml(const relu_& item, std::ostream& out)
{
out << "<relu/>\n";
out << "<relu";
if (item.disabled)
{
out << " disabled='"<< std::boolalpha << item.disabled << "'";
}
out << "/>\n";
}
private:
resizable_tensor params;
bool disabled = false;
};
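
The per-layer controls added above can also be driven by hand, which is essentially what the fusing visitor does. A sketch with an illustrative network (layer indices depend on the network definition and are assumptions here):

#include <dlib/dnn.h>

using namespace dlib;

using tiny_net = loss_binary_log<fc<1, relu<con<8, 3, 3, 1, 1, input<matrix<float>>>>>>;

void fuse_by_hand(tiny_net& net)
{
    // In this network layer<2> is the con_ layer and layer<1> is the relu_ layer.
    layer<2>(net).layer_details().enable_relu();  // con_ now applies relu itself
    layer<1>(net).layer_details().disable();      // relu_ becomes a no-op identity

    DLIB_CASSERT(!layer<2>(net).layer_details().relu_is_disabled());
    DLIB_CASSERT(layer<1>(net).layer_details().is_disabled());
}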

View File

@@ -939,6 +939,28 @@ namespace dlib
- #get_bias_weight_decay_multiplier() == val
!*/
void disable_relu(
);
/*!
ensures
- relu_is_disabled() returns true
!*/
void enable_relu(
);
/*!
ensures
- relu_is_disabled() returns false
!*/
bool relu_is_disabled(
) const;
/*!
ensures
- returns true if relu is disabled for this layer. This means no activation function
will be applied after the convolution when calling forward.
!*/
void disable_bias(
);
/*!
@@ -1821,22 +1843,6 @@ namespace dlib
template <typename SUBNET>
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
// ----------------------------------------------------------------------------------------
template <typename net_type>
void disable_duplicative_biases (
const net_type& net
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- Disables bias for all bn_ and layer_norm_ inputs.
- Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier()
to zero of all bn_ and layer_norm_ inputs.
!*/
// ----------------------------------------------------------------------------------------
class affine_
@@ -2274,6 +2280,15 @@ namespace dlib
relu_(
);
void disable(
);
/*!
ensures
- #get_layer_params().size() == 0.
- when forward_inplace and backward_inplace are called, they return immediately doing nothing,
causing this layer to trivially perform the identity transform.
!*/
template <typename SUBNET> void setup (const SUBNET& sub);
void forward_inplace(const tensor& input, tensor& output);
void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);

View File

@@ -3,6 +3,7 @@
#ifndef DLIB_DNn_VISITORS_H_
#define DLIB_DNn_VISITORS_H_
#include "visitors_abstract.h"
#include "input.h"
#include "layers.h"
#include "loss.h"
@@ -401,7 +402,42 @@ namespace dlib
// disable other layer types
}
// handle the standard case (convolutional layer followed by affine;
// handle the case of convolutional layer followed by relu
template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename R>
void fuse_convolution(add_layer<relu_, add_layer<con_<nf, nr, nc, sy, sx, py, px>, U>, R>& l)
{
if (l.layer_details().is_disabled())
return;
// get the convolution below the relu layer
auto& conv = l.subnet().layer_details();
conv.enable_relu();
// disable the relu layer
l.layer_details().disable();
}
// handle the case of convolutional layer followed by affine followed by relu
template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E, typename R>
void fuse_convolution(add_layer<relu_, add_layer<affine_, add_layer<con_<nf, nr, nc, sy, sx, py, px>, U>, E>, R>& l)
{
if (l.layer_details().is_disabled())
return;
// fuse the convolutional layer followed by affine
fuse_convolution(l.subnet());
// get the convolution below the affine layer
auto& conv = l.subnet().subnet().layer_details();
conv.enable_relu();
// disable the relu layer
l.layer_details().disable();
}
// handle the case of convolutional layer followed by affine
template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E>
void fuse_convolution(add_layer<affine_, add_layer<con_<nf, nr, nc, sy, sx, py, px>, U>, E>& l)
{

View File

@@ -27,6 +27,22 @@ namespace dlib
new_window_size.
!*/
// ----------------------------------------------------------------------------------------
template <typename net_type>
void disable_duplicative_biases (
const net_type& net
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- Disables bias for all bn_ and layer_norm_ inputs.
- Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier()
to zero of all bn_ and layer_norm_ inputs.
!*/
// ----------------------------------------------------------------------------------------
template <typename net_type>
@@ -42,6 +58,11 @@ namespace dlib
- Disables all the affine_ layers that have a convolution as an input.
- Updates the convolution weights beneath the affine_ layers to produce the same
output as with the affine_ layers enabled.
- Disables all the relu_ layers that have a convolution as input.
- Disables all the relu_ layers that have an affine_ layer as input, with a
convolution as input.
- Updates the convolution to apply a relu activation function, to produce the same
output as with the relu_ layer enabled.
!*/
// ----------------------------------------------------------------------------------------