mirror of https://github.com/davisking/dlib.git
Moved most of the layer parameters from runtime variables set in constructors
to template arguments. This way, the type of a network specifies the entire network architecture and most of the time the user doesn't even need to do anything with layer constructors.
This commit is contained in:
parent
001bca78e3
commit
fe168596a2
|
@ -19,31 +19,25 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
long _num_filters,
|
||||
long _nr,
|
||||
long _nc,
|
||||
int _stride_y,
|
||||
int _stride_x
|
||||
>
|
||||
class con_
|
||||
{
|
||||
public:
|
||||
|
||||
con_ (
|
||||
) :
|
||||
_num_filters(1),
|
||||
_nr(3),
|
||||
_nc(3),
|
||||
_stride_y(1),
|
||||
_stride_x(1)
|
||||
{}
|
||||
static_assert(_num_filters > 0, "The number of filters must be > 0");
|
||||
static_assert(_nr > 0, "The number of rows in a filter must be > 0");
|
||||
static_assert(_nc > 0, "The number of columns in a filter must be > 0");
|
||||
static_assert(_stride_y > 0, "The filter stride must be > 0");
|
||||
static_assert(_stride_x > 0, "The filter stride must be > 0");
|
||||
|
||||
con_(
|
||||
long num_filters_,
|
||||
long nr_,
|
||||
long nc_,
|
||||
int stride_y_ = 1,
|
||||
int stride_x_ = 1
|
||||
) :
|
||||
_num_filters(num_filters_),
|
||||
_nr(nr_),
|
||||
_nc(nc_),
|
||||
_stride_y(stride_y_),
|
||||
_stride_x(stride_x_)
|
||||
)
|
||||
{}
|
||||
|
||||
long num_filters() const { return _num_filters; }
|
||||
|
@ -56,11 +50,6 @@ namespace dlib
|
|||
const con_& item
|
||||
) :
|
||||
params(item.params),
|
||||
_num_filters(item._num_filters),
|
||||
_nr(item._nr),
|
||||
_nc(item._nc),
|
||||
_stride_y(item._stride_y),
|
||||
_stride_x(item._stride_x),
|
||||
filters(item.filters),
|
||||
biases(item.biases)
|
||||
{
|
||||
|
@ -78,11 +67,6 @@ namespace dlib
|
|||
// this->conv is non-copyable and basically stateless, so we have to write our
|
||||
// own copy to avoid trying to copy it and getting an error.
|
||||
params = item.params;
|
||||
_num_filters = item._num_filters;
|
||||
_nr = item._nr;
|
||||
_nc = item._nc;
|
||||
_stride_y = item._stride_y;
|
||||
_stride_x = item._stride_x;
|
||||
filters = item.filters;
|
||||
biases = item.biases;
|
||||
return *this;
|
||||
|
@ -135,11 +119,11 @@ namespace dlib
|
|||
{
|
||||
serialize("con_", out);
|
||||
serialize(item.params, out);
|
||||
serialize(item._num_filters, out);
|
||||
serialize(item._nr, out);
|
||||
serialize(item._nc, out);
|
||||
serialize(item._stride_y, out);
|
||||
serialize(item._stride_x, out);
|
||||
serialize(_num_filters, out);
|
||||
serialize(_nr, out);
|
||||
serialize(_nc, out);
|
||||
serialize(_stride_y, out);
|
||||
serialize(_stride_x, out);
|
||||
serialize(item.filters, out);
|
||||
serialize(item.biases, out);
|
||||
}
|
||||
|
@ -151,57 +135,66 @@ namespace dlib
|
|||
if (version != "con_")
|
||||
throw serialization_error("Unexpected version found while deserializing dlib::con_.");
|
||||
deserialize(item.params, in);
|
||||
deserialize(item._num_filters, in);
|
||||
deserialize(item._nr, in);
|
||||
deserialize(item._nc, in);
|
||||
deserialize(item._stride_y, in);
|
||||
deserialize(item._stride_x, in);
|
||||
|
||||
|
||||
long num_filters;
|
||||
long nr;
|
||||
long nc;
|
||||
int stride_y;
|
||||
int stride_x;
|
||||
deserialize(num_filters, in);
|
||||
deserialize(nr, in);
|
||||
deserialize(nc, in);
|
||||
deserialize(stride_y, in);
|
||||
deserialize(stride_x, in);
|
||||
deserialize(item.filters, in);
|
||||
deserialize(item.biases, in);
|
||||
|
||||
if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
|
||||
if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
|
||||
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
|
||||
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
|
||||
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
resizable_tensor params;
|
||||
long _num_filters;
|
||||
long _nr;
|
||||
long _nc;
|
||||
int _stride_y;
|
||||
int _stride_x;
|
||||
alias_tensor filters, biases;
|
||||
|
||||
tt::tensor_conv conv;
|
||||
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using con = add_layer<con_, SUBNET>;
|
||||
template <
|
||||
long num_filters,
|
||||
long nr,
|
||||
long nc,
|
||||
int stride_y,
|
||||
int stride_x,
|
||||
typename SUBNET
|
||||
>
|
||||
using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
long _nr,
|
||||
long _nc,
|
||||
int _stride_y,
|
||||
int _stride_x
|
||||
>
|
||||
class max_pool_
|
||||
{
|
||||
static_assert(_nr > 0, "The number of rows in a filter must be > 0");
|
||||
static_assert(_nc > 0, "The number of columns in a filter must be > 0");
|
||||
static_assert(_stride_y > 0, "The filter stride must be > 0");
|
||||
static_assert(_stride_x > 0, "The filter stride must be > 0");
|
||||
public:
|
||||
|
||||
max_pool_ (
|
||||
) :
|
||||
_nr(3),
|
||||
_nc(3),
|
||||
_stride_y(1),
|
||||
_stride_x(1)
|
||||
{}
|
||||
|
||||
max_pool_(
|
||||
long nr_,
|
||||
long nc_,
|
||||
int stride_y_ = 1,
|
||||
int stride_x_ = 1
|
||||
) :
|
||||
_nr(nr_),
|
||||
_nc(nc_),
|
||||
_stride_y(stride_y_),
|
||||
_stride_x(stride_x_)
|
||||
{}
|
||||
) {}
|
||||
|
||||
long nr() const { return _nr; }
|
||||
long nc() const { return _nc; }
|
||||
|
@ -209,12 +202,8 @@ namespace dlib
|
|||
long stride_x() const { return _stride_x; }
|
||||
|
||||
max_pool_ (
|
||||
const max_pool_& item
|
||||
) :
|
||||
_nr(item._nr),
|
||||
_nc(item._nc),
|
||||
_stride_y(item._stride_y),
|
||||
_stride_x(item._stride_x)
|
||||
const max_pool_&
|
||||
)
|
||||
{
|
||||
// this->mp is non-copyable so we have to write our own copy to avoid trying to
|
||||
// copy it and getting an error.
|
||||
|
@ -230,11 +219,6 @@ namespace dlib
|
|||
|
||||
// this->mp is non-copyable so we have to write our own copy to avoid trying to
|
||||
// copy it and getting an error.
|
||||
_nr = item._nr;
|
||||
_nc = item._nc;
|
||||
_stride_y = item._stride_y;
|
||||
_stride_x = item._stride_x;
|
||||
|
||||
mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
|
||||
return *this;
|
||||
}
|
||||
|
@ -263,10 +247,10 @@ namespace dlib
|
|||
friend void serialize(const max_pool_& item, std::ostream& out)
|
||||
{
|
||||
serialize("max_pool_", out);
|
||||
serialize(item._nr, out);
|
||||
serialize(item._nc, out);
|
||||
serialize(item._stride_y, out);
|
||||
serialize(item._stride_x, out);
|
||||
serialize(_nr, out);
|
||||
serialize(_nc, out);
|
||||
serialize(_stride_y, out);
|
||||
serialize(_stride_x, out);
|
||||
}
|
||||
|
||||
friend void deserialize(max_pool_& item, std::istream& in)
|
||||
|
@ -275,53 +259,58 @@ namespace dlib
|
|||
deserialize(version, in);
|
||||
if (version != "max_pool_")
|
||||
throw serialization_error("Unexpected version found while deserializing dlib::max_pool_.");
|
||||
deserialize(item._nr, in);
|
||||
deserialize(item._nc, in);
|
||||
deserialize(item._stride_y, in);
|
||||
deserialize(item._stride_x, in);
|
||||
|
||||
item.mp.setup_max_pooling(item._nr, item._nc, item._stride_y, item._stride_x);
|
||||
item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
|
||||
|
||||
long nr;
|
||||
long nc;
|
||||
int stride_y;
|
||||
int stride_x;
|
||||
|
||||
deserialize(nr, in);
|
||||
deserialize(nc, in);
|
||||
deserialize(stride_y, in);
|
||||
deserialize(stride_x, in);
|
||||
if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
|
||||
if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
|
||||
if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
|
||||
if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
long _nr;
|
||||
long _nc;
|
||||
int _stride_y;
|
||||
int _stride_x;
|
||||
|
||||
tt::pooling mp;
|
||||
resizable_tensor params;
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using max_pool = add_layer<max_pool_, SUBNET>;
|
||||
template <
|
||||
long nr,
|
||||
long nc,
|
||||
int stride_y,
|
||||
int stride_x,
|
||||
typename SUBNET
|
||||
>
|
||||
using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
long _nr,
|
||||
long _nc,
|
||||
int _stride_y,
|
||||
int _stride_x
|
||||
>
|
||||
class avg_pool_
|
||||
{
|
||||
public:
|
||||
|
||||
avg_pool_ (
|
||||
) :
|
||||
_nr(3),
|
||||
_nc(3),
|
||||
_stride_y(1),
|
||||
_stride_x(1)
|
||||
{}
|
||||
static_assert(_nr > 0, "The number of rows in a filter must be > 0");
|
||||
static_assert(_nc > 0, "The number of columns in a filter must be > 0");
|
||||
static_assert(_stride_y > 0, "The filter stride must be > 0");
|
||||
static_assert(_stride_x > 0, "The filter stride must be > 0");
|
||||
|
||||
avg_pool_(
|
||||
long nr_,
|
||||
long nc_,
|
||||
int stride_y_ = 1,
|
||||
int stride_x_ = 1
|
||||
) :
|
||||
_nr(nr_),
|
||||
_nc(nc_),
|
||||
_stride_y(stride_y_),
|
||||
_stride_x(stride_x_)
|
||||
{}
|
||||
) {}
|
||||
|
||||
long nr() const { return _nr; }
|
||||
long nc() const { return _nc; }
|
||||
|
@ -329,12 +318,8 @@ namespace dlib
|
|||
long stride_x() const { return _stride_x; }
|
||||
|
||||
avg_pool_ (
|
||||
const avg_pool_& item
|
||||
) :
|
||||
_nr(item._nr),
|
||||
_nc(item._nc),
|
||||
_stride_y(item._stride_y),
|
||||
_stride_x(item._stride_x)
|
||||
const avg_pool_&
|
||||
)
|
||||
{
|
||||
// this->ap is non-copyable so we have to write our own copy to avoid trying to
|
||||
// copy it and getting an error.
|
||||
|
@ -350,11 +335,6 @@ namespace dlib
|
|||
|
||||
// this->ap is non-copyable so we have to write our own copy to avoid trying to
|
||||
// copy it and getting an error.
|
||||
_nr = item._nr;
|
||||
_nc = item._nc;
|
||||
_stride_y = item._stride_y;
|
||||
_stride_x = item._stride_x;
|
||||
|
||||
ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
|
||||
return *this;
|
||||
}
|
||||
|
@ -383,10 +363,10 @@ namespace dlib
|
|||
friend void serialize(const avg_pool_& item, std::ostream& out)
|
||||
{
|
||||
serialize("avg_pool_", out);
|
||||
serialize(item._nr, out);
|
||||
serialize(item._nc, out);
|
||||
serialize(item._stride_y, out);
|
||||
serialize(item._stride_x, out);
|
||||
serialize(_nr, out);
|
||||
serialize(_nc, out);
|
||||
serialize(_stride_y, out);
|
||||
serialize(_stride_x, out);
|
||||
}
|
||||
|
||||
friend void deserialize(avg_pool_& item, std::istream& in)
|
||||
|
@ -395,27 +375,38 @@ namespace dlib
|
|||
deserialize(version, in);
|
||||
if (version != "avg_pool_")
|
||||
throw serialization_error("Unexpected version found while deserializing dlib::avg_pool_.");
|
||||
deserialize(item._nr, in);
|
||||
deserialize(item._nc, in);
|
||||
deserialize(item._stride_y, in);
|
||||
deserialize(item._stride_x, in);
|
||||
|
||||
item.ap.setup_avg_pooling(item._nr, item._nc, item._stride_y, item._stride_x);
|
||||
item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
|
||||
|
||||
long nr;
|
||||
long nc;
|
||||
int stride_y;
|
||||
int stride_x;
|
||||
|
||||
deserialize(nr, in);
|
||||
deserialize(nc, in);
|
||||
deserialize(stride_y, in);
|
||||
deserialize(stride_x, in);
|
||||
if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
|
||||
if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
|
||||
if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
|
||||
if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
long _nr;
|
||||
long _nc;
|
||||
int _stride_y;
|
||||
int _stride_x;
|
||||
|
||||
tt::pooling ap;
|
||||
resizable_tensor params;
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using avg_pool = add_layer<avg_pool_, SUBNET>;
|
||||
template <
|
||||
long nr,
|
||||
long nc,
|
||||
int stride_y,
|
||||
int stride_x,
|
||||
typename SUBNET
|
||||
>
|
||||
using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -425,16 +416,16 @@ namespace dlib
|
|||
FC_MODE = 1
|
||||
};
|
||||
|
||||
template <
|
||||
layer_mode mode
|
||||
>
|
||||
class bn_
|
||||
{
|
||||
public:
|
||||
bn_() : num_updates(0), running_stats_window_size(1000), mode(FC_MODE)
|
||||
bn_() : num_updates(0), running_stats_window_size(1000)
|
||||
{}
|
||||
|
||||
explicit bn_(layer_mode mode_) : num_updates(0), running_stats_window_size(1000), mode(mode_)
|
||||
{}
|
||||
|
||||
bn_(layer_mode mode_, unsigned long window_size) : num_updates(0), running_stats_window_size(window_size), mode(mode_)
|
||||
explicit bn_(unsigned long window_size) : num_updates(0), running_stats_window_size(window_size)
|
||||
{}
|
||||
|
||||
layer_mode get_mode() const { return mode; }
|
||||
|
@ -519,7 +510,7 @@ namespace dlib
|
|||
serialize(item.running_invstds, out);
|
||||
serialize(item.num_updates, out);
|
||||
serialize(item.running_stats_window_size, out);
|
||||
serialize((int)item.mode, out);
|
||||
serialize((int)mode, out);
|
||||
}
|
||||
|
||||
friend void deserialize(bn_& item, std::istream& in)
|
||||
|
@ -537,13 +528,14 @@ namespace dlib
|
|||
deserialize(item.running_invstds, in);
|
||||
deserialize(item.num_updates, in);
|
||||
deserialize(item.running_stats_window_size, in);
|
||||
int mode;
|
||||
deserialize(mode, in);
|
||||
item.mode = (layer_mode)mode;
|
||||
int _mode;
|
||||
deserialize(_mode, in);
|
||||
if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::bn_");
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
template < layer_mode Mode >
|
||||
friend class affine_;
|
||||
|
||||
resizable_tensor params;
|
||||
|
@ -552,32 +544,41 @@ namespace dlib
|
|||
resizable_tensor invstds, running_invstds;
|
||||
unsigned long num_updates;
|
||||
unsigned long running_stats_window_size;
|
||||
layer_mode mode;
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using bn = add_layer<bn_, SUBNET>;
|
||||
using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
|
||||
template <typename SUBNET>
|
||||
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
enum fc_bias_mode{
|
||||
enum fc_bias_mode
|
||||
{
|
||||
FC_HAS_BIAS = 0,
|
||||
FC_NO_BIAS = 1
|
||||
};
|
||||
|
||||
struct num_fc_outputs
|
||||
{
|
||||
num_fc_outputs(unsigned long n) : num_outputs(n) {}
|
||||
unsigned long num_outputs;
|
||||
};
|
||||
|
||||
template <
|
||||
unsigned long num_outputs_,
|
||||
fc_bias_mode bias_mode
|
||||
>
|
||||
class fc_
|
||||
{
|
||||
static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0");
|
||||
|
||||
public:
|
||||
fc_() : num_outputs(1), num_inputs(0), bias_mode(FC_HAS_BIAS)
|
||||
fc_() : num_outputs(num_outputs_), num_inputs(0)
|
||||
{
|
||||
}
|
||||
|
||||
explicit fc_(
|
||||
unsigned long num_outputs_,
|
||||
fc_bias_mode mode = FC_HAS_BIAS
|
||||
) : num_outputs(num_outputs_), num_inputs(0), bias_mode(mode)
|
||||
{
|
||||
}
|
||||
fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0) {}
|
||||
|
||||
unsigned long get_num_outputs (
|
||||
) const { return num_outputs; }
|
||||
|
@ -651,7 +652,7 @@ namespace dlib
|
|||
serialize(item.params, out);
|
||||
serialize(item.weights, out);
|
||||
serialize(item.biases, out);
|
||||
serialize((int)item.bias_mode, out);
|
||||
serialize((int)bias_mode, out);
|
||||
}
|
||||
|
||||
friend void deserialize(fc_& item, std::istream& in)
|
||||
|
@ -660,6 +661,7 @@ namespace dlib
|
|||
deserialize(version, in);
|
||||
if (version != "fc_")
|
||||
throw serialization_error("Unexpected version found while deserializing dlib::fc_.");
|
||||
|
||||
deserialize(item.num_outputs, in);
|
||||
deserialize(item.num_inputs, in);
|
||||
deserialize(item.params, in);
|
||||
|
@ -667,7 +669,7 @@ namespace dlib
|
|||
deserialize(item.biases, in);
|
||||
int bmode = 0;
|
||||
deserialize(bmode, in);
|
||||
item.bias_mode = (fc_bias_mode)bmode;
|
||||
if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -676,11 +678,14 @@ namespace dlib
|
|||
unsigned long num_inputs;
|
||||
resizable_tensor params;
|
||||
alias_tensor weights, biases;
|
||||
fc_bias_mode bias_mode;
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using fc = add_layer<fc_, SUBNET>;
|
||||
template <
|
||||
unsigned long num_outputs,
|
||||
fc_bias_mode bias_mode,
|
||||
typename SUBNET
|
||||
>
|
||||
using fc = add_layer<fc_<num_outputs,bias_mode>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -849,27 +854,22 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
layer_mode mode
|
||||
>
|
||||
class affine_
|
||||
{
|
||||
public:
|
||||
affine_(
|
||||
) : mode(FC_MODE)
|
||||
{
|
||||
}
|
||||
|
||||
explicit affine_(
|
||||
layer_mode mode_
|
||||
) : mode(mode_)
|
||||
{
|
||||
}
|
||||
)
|
||||
{}
|
||||
|
||||
affine_(
|
||||
const bn_& item
|
||||
const bn_<mode>& item
|
||||
)
|
||||
{
|
||||
gamma = item.gamma;
|
||||
beta = item.beta;
|
||||
mode = item.mode;
|
||||
|
||||
params.copy_size(item.params);
|
||||
|
||||
|
@ -959,7 +959,7 @@ namespace dlib
|
|||
// Since we can build an affine_ from a bn_ we check if that's what is in
|
||||
// the stream and if so then just convert it right here.
|
||||
unserialize sin(version, in);
|
||||
bn_ temp;
|
||||
bn_<mode> temp;
|
||||
deserialize(temp, sin);
|
||||
item = temp;
|
||||
return;
|
||||
|
@ -970,19 +970,20 @@ namespace dlib
|
|||
deserialize(item.params, in);
|
||||
deserialize(item.gamma, in);
|
||||
deserialize(item.beta, in);
|
||||
int mode;
|
||||
deserialize(mode, in);
|
||||
item.mode = (layer_mode)mode;
|
||||
int _mode;
|
||||
deserialize(_mode, in);
|
||||
if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::affine_");
|
||||
}
|
||||
|
||||
private:
|
||||
resizable_tensor params, empty_params;
|
||||
alias_tensor gamma, beta;
|
||||
layer_mode mode;
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using affine = add_layer<affine_, SUBNET>;
|
||||
using affine_con = add_layer<affine_<CONV_MODE>, SUBNET>;
|
||||
template <typename SUBNET>
|
||||
using affine_fc = add_layer<affine_<FC_MODE>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -1129,6 +1130,9 @@ namespace dlib
|
|||
{
|
||||
}
|
||||
|
||||
float get_initial_param_value (
|
||||
) const { return initial_param_value; }
|
||||
|
||||
template <typename SUBNET>
|
||||
void setup (const SUBNET& /*sub*/)
|
||||
{
|
||||
|
|
|
@ -322,14 +322,28 @@ namespace dlib
|
|||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
enum fc_bias_mode{
|
||||
enum fc_bias_mode
|
||||
{
|
||||
FC_HAS_BIAS = 0,
|
||||
FC_NO_BIAS = 1
|
||||
};
|
||||
|
||||
struct num_fc_outputs
|
||||
{
|
||||
num_fc_outputs(unsigned long n) : num_outputs(n) {}
|
||||
unsigned long num_outputs;
|
||||
};
|
||||
|
||||
template <
|
||||
unsigned long num_outputs,
|
||||
fc_bias_mode bias_mode
|
||||
>
|
||||
class fc_
|
||||
{
|
||||
/*!
|
||||
REQUIREMENTS ON num_outputs
|
||||
num_outputs > 0
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is an implementation of the EXAMPLE_LAYER_ interface defined above.
|
||||
In particular, it defines a fully connected layer that takes an input
|
||||
|
@ -337,24 +351,13 @@ namespace dlib
|
|||
!*/
|
||||
|
||||
public:
|
||||
|
||||
fc_(
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_num_outputs() == 1
|
||||
- #get_bias_mode() == FC_HAS_BIAS
|
||||
!*/
|
||||
|
||||
explicit fc_(
|
||||
unsigned long num_outputs,
|
||||
fc_bias_mode mode = FC_HAS_BIAS
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- num_outputs > 0
|
||||
ensures
|
||||
- #get_num_outputs() == num_outputs
|
||||
- #get_bias_mode() == mode
|
||||
- #get_bias_mode() == bias_mode
|
||||
!*/
|
||||
|
||||
unsigned long get_num_outputs (
|
||||
|
@ -385,22 +388,37 @@ namespace dlib
|
|||
/*!
|
||||
These functions are implemented as described in the EXAMPLE_LAYER_ interface.
|
||||
!*/
|
||||
|
||||
friend void serialize(const fc_& item, std::ostream& out);
|
||||
friend void deserialize(fc_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
};
|
||||
|
||||
void serialize(const fc_& item, std::ostream& out);
|
||||
void deserialize(fc_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
using fc = add_layer<fc_, SUBNET>;
|
||||
template <
|
||||
unsigned long num_outputs,
|
||||
fc_bias_mode bias_mode,
|
||||
typename SUBNET
|
||||
>
|
||||
using fc = add_layer<fc_<num_outputs,bias_mode>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
long _num_filters,
|
||||
long _nr,
|
||||
long _nc,
|
||||
int _stride_y,
|
||||
int _stride_x
|
||||
>
|
||||
class con_
|
||||
{
|
||||
/*!
|
||||
REQUIREMENTS ON TEMPLATE ARGUMENTS
|
||||
All of them must be > 0.
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is an implementation of the EXAMPLE_LAYER_ interface defined above.
|
||||
In particular, it defines a convolution layer that takes an input tensor
|
||||
|
@ -420,33 +438,11 @@ namespace dlib
|
|||
);
|
||||
/*!
|
||||
ensures
|
||||
- #num_filters() == 1
|
||||
- #nr() == 3
|
||||
- #nc() == 3
|
||||
- #stride_y() == 1
|
||||
- #stride_x() == 1
|
||||
!*/
|
||||
|
||||
con_(
|
||||
long num_filters_,
|
||||
long nr_,
|
||||
long nc_,
|
||||
int stride_y_ = 1,
|
||||
int stride_x_ = 1
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- num_filters_ > 0
|
||||
- nr_ > 0
|
||||
- nc_ > 0
|
||||
- stride_y_ > 0
|
||||
- stride_x_ > 0
|
||||
ensures
|
||||
- #num_filters() == num_filters_
|
||||
- #nr() == nr_
|
||||
- #nc() == nc_
|
||||
- #stride_y() == stride_y_
|
||||
- #stride_x() == stride_x_
|
||||
- #num_filters() == _num_filters
|
||||
- #nr() == _nr
|
||||
- #nc() == _nc
|
||||
- #stride_y() == _stride_y
|
||||
- #stride_x() == _stride_x
|
||||
!*/
|
||||
|
||||
long num_filters(
|
||||
|
@ -498,16 +494,24 @@ namespace dlib
|
|||
/*!
|
||||
These functions are implemented as described in the EXAMPLE_LAYER_ interface.
|
||||
!*/
|
||||
|
||||
friend void serialize(const con_& item, std::ostream& out);
|
||||
friend void deserialize(con_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
};
|
||||
|
||||
void serialize(const con_& item, std::ostream& out);
|
||||
void deserialize(con_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
using con = add_layer<con_, SUBNET>;
|
||||
template <
|
||||
long num_filters,
|
||||
long nr,
|
||||
long nc,
|
||||
int stride_y,
|
||||
int stride_x,
|
||||
typename SUBNET
|
||||
>
|
||||
using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -631,6 +635,9 @@ namespace dlib
|
|||
FC_MODE = 1 // fully connected mode
|
||||
};
|
||||
|
||||
template <
|
||||
layer_mode mode
|
||||
>
|
||||
class bn_
|
||||
{
|
||||
/*!
|
||||
|
@ -663,17 +670,17 @@ namespace dlib
|
|||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_mode() == FC_MODE
|
||||
- #get_mode() == mode
|
||||
- get_running_stats_window_size() == 1000
|
||||
!*/
|
||||
|
||||
explicit bn_(
|
||||
layer_mode mode
|
||||
unsigned long window_size
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_mode() == mode
|
||||
- get_running_stats_window_size() == 1000
|
||||
- get_running_stats_window_size() == window_size
|
||||
!*/
|
||||
|
||||
layer_mode get_mode(
|
||||
|
@ -713,19 +720,25 @@ namespace dlib
|
|||
/*!
|
||||
These functions are implemented as described in the EXAMPLE_LAYER_ interface.
|
||||
!*/
|
||||
|
||||
friend void serialize(const bn_& item, std::ostream& out);
|
||||
friend void deserialize(bn_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
};
|
||||
|
||||
void serialize(const bn_& item, std::ostream& out);
|
||||
void deserialize(bn_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
using bn = add_layer<bn_, SUBNET>;
|
||||
using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
|
||||
template <typename SUBNET>
|
||||
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
layer_mode mode
|
||||
>
|
||||
class affine_
|
||||
{
|
||||
/*!
|
||||
|
@ -766,11 +779,11 @@ namespace dlib
|
|||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_mode() == FC_MODE
|
||||
- #get_mode() == mode
|
||||
!*/
|
||||
|
||||
affine_(
|
||||
const bn_& layer
|
||||
const bn_<mode>& layer
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
|
@ -781,14 +794,6 @@ namespace dlib
|
|||
- #get_mode() == layer.get_mode()
|
||||
!*/
|
||||
|
||||
explicit affine_(
|
||||
layer_mode mode
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_mode() == mode
|
||||
!*/
|
||||
|
||||
layer_mode get_mode(
|
||||
) const;
|
||||
/*!
|
||||
|
@ -806,22 +811,33 @@ namespace dlib
|
|||
Also note that get_layer_params() always returns an empty tensor since there
|
||||
are no learnable parameters in this object.
|
||||
!*/
|
||||
|
||||
friend void serialize(const affine_& item, std::ostream& out);
|
||||
friend void deserialize(affine_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
};
|
||||
|
||||
void serialize(const affine_& item, std::ostream& out);
|
||||
void deserialize(affine_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
using affine = add_layer<affine_, SUBNET>;
|
||||
using affine_con = add_layer<affine_<CONV_MODE>, SUBNET>;
|
||||
template <typename SUBNET>
|
||||
using affine_fc = add_layer<affine_<FC_MODE>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
long _nr,
|
||||
long _nc,
|
||||
int _stride_y,
|
||||
int _stride_x
|
||||
>
|
||||
class max_pool_
|
||||
{
|
||||
/*!
|
||||
REQUIREMENTS ON TEMPLATE ARGUMENTS
|
||||
All of them must be > 0.
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is an implementation of the EXAMPLE_LAYER_ interface defined above.
|
||||
In particular, it defines a max pooling layer that takes an input tensor
|
||||
|
@ -849,24 +865,10 @@ namespace dlib
|
|||
);
|
||||
/*!
|
||||
ensures
|
||||
- #nr() == 3
|
||||
- #nc() == 3
|
||||
- #stride_y() == 1
|
||||
- #stride_x() == 1
|
||||
!*/
|
||||
|
||||
max_pool_(
|
||||
long nr_,
|
||||
long nc_,
|
||||
int stride_y_ = 1,
|
||||
int stride_x_ = 1
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #nr() == nr_
|
||||
- #nc() == nc_
|
||||
- #stride_y() == stride_y_
|
||||
- #stride_x() == stride_x_
|
||||
- #nr() == _nr
|
||||
- #nc() == _nc
|
||||
- #stride_y() == _stride_y
|
||||
- #stride_x() == _stride_x
|
||||
!*/
|
||||
|
||||
long nr(
|
||||
|
@ -911,22 +913,37 @@ namespace dlib
|
|||
Note that this layer doesn't have any parameters, so the tensor returned by
|
||||
get_layer_params() is always empty.
|
||||
!*/
|
||||
|
||||
friend void serialize(const max_pool_& item, std::ostream& out);
|
||||
friend void deserialize(max_pool_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
};
|
||||
|
||||
void serialize(const max_pool_& item, std::ostream& out);
|
||||
void deserialize(max_pool_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
using max_pool = add_layer<max_pool_, SUBNET>;
|
||||
template <
|
||||
long nr,
|
||||
long nc,
|
||||
int stride_y,
|
||||
int stride_x,
|
||||
typename SUBNET
|
||||
>
|
||||
using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
long _nr,
|
||||
long _nc,
|
||||
int _stride_y,
|
||||
int _stride_x
|
||||
>
|
||||
class avg_pool_
|
||||
{
|
||||
/*!
|
||||
REQUIREMENTS ON TEMPLATE ARGUMENTS
|
||||
All of them must be > 0.
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is an implementation of the EXAMPLE_LAYER_ interface defined above.
|
||||
In particular, it defines an average pooling layer that takes an input tensor
|
||||
|
@ -954,24 +971,10 @@ namespace dlib
|
|||
);
|
||||
/*!
|
||||
ensures
|
||||
- #nr() == 3
|
||||
- #nc() == 3
|
||||
- #stride_y() == 1
|
||||
- #stride_x() == 1
|
||||
!*/
|
||||
|
||||
avg_pool_(
|
||||
long nr_,
|
||||
long nc_,
|
||||
int stride_y_ = 1,
|
||||
int stride_x_ = 1
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #nr() == nr_
|
||||
- #nc() == nc_
|
||||
- #stride_y() == stride_y_
|
||||
- #stride_x() == stride_x_
|
||||
- #nr() == _nr
|
||||
- #nc() == _nc
|
||||
- #stride_y() == _stride_y
|
||||
- #stride_x() == _stride_x
|
||||
!*/
|
||||
|
||||
long nr(
|
||||
|
@ -1016,16 +1019,22 @@ namespace dlib
|
|||
Note that this layer doesn't have any parameters, so the tensor returned by
|
||||
get_layer_params() is always empty.
|
||||
!*/
|
||||
|
||||
friend void serialize(const avg_pool_& item, std::ostream& out);
|
||||
friend void deserialize(avg_pool_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
};
|
||||
|
||||
void serialize(const avg_pool_& item, std::ostream& out);
|
||||
void deserialize(avg_pool_& item, std::istream& in);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
using avg_pool = add_layer<avg_pool_, SUBNET>;
|
||||
template <
|
||||
long nr,
|
||||
long nc,
|
||||
int stride_y,
|
||||
int stride_x,
|
||||
typename SUBNET
|
||||
>
|
||||
using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -1094,6 +1103,14 @@ namespace dlib
|
|||
/*!
|
||||
ensures
|
||||
- The p parameter will be initialized with initial_param_value.
|
||||
- #get_initial_param_value() == initial_param_value.
|
||||
!*/
|
||||
|
||||
float get_initial_param_value (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the initial value of the prelu parameter.
|
||||
!*/
|
||||
|
||||
template <typename SUBNET> void setup (const SUBNET& sub);
|
||||
|
|
|
@ -1076,67 +1076,67 @@ namespace
|
|||
}
|
||||
{
|
||||
print_spinner();
|
||||
max_pool_ l;
|
||||
max_pool_<3,3,1,1> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
avg_pool_ l;
|
||||
avg_pool_<3,3,1,1> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
affine_ l(CONV_MODE);
|
||||
affine_<CONV_MODE> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
affine_ l(FC_MODE);
|
||||
affine_<FC_MODE> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
bn_ l(CONV_MODE);
|
||||
bn_<CONV_MODE> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
bn_ l(FC_MODE);
|
||||
bn_<FC_MODE> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
con_ l(3,3,3,2,2);
|
||||
con_<3,3,3,2,2> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
con_ l(3,3,3,1,1);
|
||||
con_<3,3,3,1,1>l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
con_ l(3,3,2,1,1);
|
||||
con_<3,3,2,1,1> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
con_ l(2,1,1,1,1);
|
||||
con_<2,1,1,1,1> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
fc_ l;
|
||||
fc_<1,FC_HAS_BIAS> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
fc_ l(5,FC_HAS_BIAS);
|
||||
fc_<5,FC_HAS_BIAS> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
fc_ l(5,FC_NO_BIAS);
|
||||
fc_<5,FC_NO_BIAS> l;
|
||||
DLIB_TEST_MSG(test_layer(l), test_layer(l));
|
||||
}
|
||||
{
|
||||
|
@ -1168,29 +1168,16 @@ namespace
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T> using rcon = max_pool<relu<bn<con<T>>>>;
|
||||
std::tuple<max_pool_,relu_,bn_,con_> rcon_ (unsigned long n)
|
||||
{
|
||||
return std::make_tuple(max_pool_(2,2,2,2),relu_(),bn_(CONV_MODE),con_(n,5,5));
|
||||
}
|
||||
|
||||
template <typename T> using rfc = relu<bn<fc<T>>>;
|
||||
std::tuple<relu_,bn_,fc_> rfc_ (unsigned long n)
|
||||
{
|
||||
return std::make_tuple(relu_(),bn_(),fc_(n));
|
||||
}
|
||||
template <unsigned long n, typename SUBNET> using rcon = max_pool<2,2,2,2,relu<bn_con<con<n,5,5,1,1,SUBNET>>>>;
|
||||
template <unsigned long n, typename SUBNET> using rfc = relu<bn_fc<fc<n,FC_HAS_BIAS,SUBNET>>>;
|
||||
|
||||
void test_tagging(
|
||||
)
|
||||
{
|
||||
typedef loss_multiclass_log<rfc<skip1<rfc<rfc<tag1<rcon<rcon<input<matrix<unsigned char>>>>>>>>>> net_type;
|
||||
typedef loss_multiclass_log<rfc<10,skip1<rfc<84,rfc<120,tag1<rcon<16,rcon<6,input<matrix<unsigned char>>>>>>>>>> net_type;
|
||||
|
||||
net_type net(rfc_(10),
|
||||
rfc_(84),
|
||||
rfc_(120),
|
||||
rcon_(16),
|
||||
rcon_(6)
|
||||
);
|
||||
net_type net;
|
||||
net_type net2(num_fc_outputs(4));
|
||||
|
||||
DLIB_TEST(layer<tag1>(net).num_layers == 8);
|
||||
DLIB_TEST(layer<skip1>(net).num_layers == 8+3+3);
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
using namespace std;
|
||||
using namespace dlib;
|
||||
|
||||
|
||||
int main(int argc, char** argv) try
|
||||
{
|
||||
if (argc != 2)
|
||||
|
@ -23,6 +24,8 @@ int main(int argc, char** argv) try
|
|||
return 1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::vector<matrix<unsigned char>> training_images;
|
||||
std::vector<unsigned long> training_labels;
|
||||
std::vector<matrix<unsigned char>> testing_images;
|
||||
|
@ -30,22 +33,18 @@ int main(int argc, char** argv) try
|
|||
load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);
|
||||
|
||||
|
||||
typedef loss_multiclass_log<fc<relu<fc<relu<fc<max_pool<relu<con<max_pool<relu<con<
|
||||
input<matrix<unsigned char>>>>>>>>>>>>>> net_type;
|
||||
using net_type = loss_multiclass_log<
|
||||
fc<10,FC_HAS_BIAS,
|
||||
relu<fc<84,FC_HAS_BIAS,
|
||||
relu<fc<120,FC_HAS_BIAS,
|
||||
max_pool<2,2,2,2,relu<con<16,5,5,1,1,
|
||||
max_pool<2,2,2,2,relu<con<6,5,5,1,1,
|
||||
input<matrix<unsigned char>>>>>>>>>>>>>>;
|
||||
|
||||
net_type net(fc_(10),
|
||||
relu_(),
|
||||
fc_(84),
|
||||
relu_(),
|
||||
fc_(120),
|
||||
max_pool_(2,2,2,2),
|
||||
relu_(),
|
||||
con_(16,5,5),
|
||||
max_pool_(2,2,2,2),
|
||||
relu_(),
|
||||
con_(6,5,5));
|
||||
|
||||
dnn_trainer<net_type> trainer(net,sgd(0.1));
|
||||
net_type net;
|
||||
|
||||
dnn_trainer<net_type> trainer(net,sgd(0.01));
|
||||
trainer.set_mini_batch_size(128);
|
||||
trainer.be_verbose();
|
||||
trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
|
||||
|
|
|
@ -9,23 +9,19 @@ using namespace dlib;
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T> using res = relu<add_prev1<bn<con<relu<bn<con<tag1<T>>>>>>>>;
|
||||
template <int stride, typename SUBNET>
|
||||
using base_res = relu<add_prev1< bn_con<con<8,3,3,1,1,relu< bn_con<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;
|
||||
|
||||
std::tuple<relu_,add_prev1_,bn_,con_,relu_,bn_,con_> res_ (
|
||||
unsigned long outputs,
|
||||
unsigned long stride = 1
|
||||
)
|
||||
{
|
||||
return std::make_tuple(relu_(),
|
||||
add_prev1_(),
|
||||
bn_(CONV_MODE),
|
||||
con_(outputs,3,3,stride,stride),
|
||||
relu_(),
|
||||
bn_(CONV_MODE),
|
||||
con_(outputs,3,3,stride,stride));
|
||||
}
|
||||
template <int stride, typename SUBNET>
|
||||
using base_ares = relu<add_prev1<affine_con<con<8,3,3,1,1,relu<affine_con<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;
|
||||
|
||||
template <typename T> using ares = relu<add_prev1<affine<con<relu<affine<con<tag1<T>>>>>>>>;
|
||||
template <typename SUBNET> using res = base_res<1,SUBNET>;
|
||||
template <typename SUBNET> using res_down = base_res<2,SUBNET>;
|
||||
template <typename SUBNET> using ares = base_ares<1,SUBNET>;
|
||||
template <typename SUBNET> using ares_down = base_ares<2,SUBNET>;
|
||||
|
||||
template <typename SUBNET>
|
||||
using pres = prelu<add_prev1< bn_con<con<8,3,3,1,1,prelu< bn_con<con<8,3,3,1,1,tag1<SUBNET>>>>>>>>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -44,24 +40,78 @@ int main(int argc, char** argv) try
|
|||
load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);
|
||||
|
||||
|
||||
set_dnn_prefer_smallest_algorithms();
|
||||
|
||||
typedef loss_multiclass_log<fc<avg_pool<
|
||||
res<res<res<res<
|
||||
repeat<10,res,
|
||||
res<
|
||||
const unsigned long number_of_classes = 10;
|
||||
typedef loss_multiclass_log<fc<number_of_classes,FC_HAS_BIAS,
|
||||
avg_pool<11,11,11,11,
|
||||
res<res<res<res_down<
|
||||
repeat<9,res, // repeat this layer 9 times
|
||||
res_down<
|
||||
res<
|
||||
input<matrix<unsigned char>
|
||||
>>>>>>>>>>> net_type;
|
||||
|
||||
|
||||
const unsigned long number_of_classes = 10;
|
||||
net_type net(fc_(number_of_classes),
|
||||
avg_pool_(10,10,10,10),
|
||||
res_(8),res_(8),res_(8),res_(8,2),
|
||||
res_(8), // repeated 10 times
|
||||
res_(8,2),
|
||||
res_(8)
|
||||
);
|
||||
net_type net;
|
||||
|
||||
|
||||
// If you wanted to use the same network but override the number of outputs at runtime
|
||||
// you can do so like this:
|
||||
net_type net2(num_fc_outputs(15));
|
||||
|
||||
// Let's imagine we wanted to replace some of the relu layers with prelu layers. We
|
||||
// might do it like this:
|
||||
typedef loss_multiclass_log<fc<number_of_classes,FC_HAS_BIAS,
|
||||
avg_pool<11,11,11,11,
|
||||
pres<res<res<res_down< // 2 prelu layers here
|
||||
tag4<repeat<9,pres, // 9 groups, each containing 2 prelu layers
|
||||
res_down<
|
||||
res<
|
||||
input<matrix<unsigned char>
|
||||
>>>>>>>>>>>> net_type2;
|
||||
|
||||
// prelu layers have a floating point parameter. If you want to set it to something
|
||||
// other than its default value you can do so like this:
|
||||
net_type2 pnet(prelu_(0.2),
|
||||
prelu_(0.2),
|
||||
repeat_group(prelu_(0.3),prelu_(0.4)) // Initialize all the prelu instances in the repeat
|
||||
// layer. repeat_group() is needed to group the things
|
||||
// that are part of repeat's block.
|
||||
);
|
||||
// As you can see, a network will greedily assign things given to its constructor to
|
||||
// the layers inside itself. The assignment is done in the order the layers are
|
||||
// defined but it will skip layers where the assignment doesn't make sense.
|
||||
|
||||
|
||||
// You can access sub layers of the network like this:
|
||||
net.subnet().subnet().get_output();
|
||||
layer<2>(net).get_output();
|
||||
layer<relu>(net).get_output();
|
||||
layer<tag1>(net).get_output();
|
||||
// To further illustrate the use of layer(), let's loop over the repeated layers and
|
||||
// print out their parameters. But first, let's grab a reference to the repeat layer.
|
||||
// Since we tagged the repeat layer we can access it using the layer() method.
|
||||
// layer<tag4>(pnet) returns the tag4 layer, but we want the repeat layer so we can
|
||||
// give an integer as the second argument and it will jump that many layers down the
|
||||
// network. In our case we need to jump just 1 layer down to get to repeat.
|
||||
auto&& repeat_layer = layer<tag4,1>(pnet);
|
||||
for (size_t i = 0; i < repeat_layer.num_repetitions(); ++i)
|
||||
{
|
||||
// The repeat layer just instantiates the network block a bunch of times as a
|
||||
// network object. get_repeated_layer() allows us to grab each of these instances.
|
||||
auto&& repeated_layer = repeat_layer.get_repeated_layer(i);
|
||||
// Now that we have the i-th layer inside our repeat layer we can look at its
|
||||
// properties. Recall that we repeated the "pres" network block, which is itself a
|
||||
// network with a bunch of layers. So we can again use layer() to jump to the
|
||||
// prelu layers we are interested in like so:
|
||||
prelu_ prelu1 = layer<prelu>(repeated_layer).layer_details();
|
||||
prelu_ prelu2 = layer<prelu>(repeated_layer.subnet()).layer_details();
|
||||
cout << "first prelu layer parameter value: "<< prelu1.get_initial_param_value() << endl;;
|
||||
cout << "second prelu layer parameter value: "<< prelu2.get_initial_param_value() << endl;;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
dnn_trainer<net_type,adam> trainer(net,adam(0.001));
|
||||
|
@ -89,20 +139,16 @@ int main(int argc, char** argv) try
|
|||
// wait for threaded processing to stop.
|
||||
trainer.get_net();
|
||||
|
||||
// You can access sub layers of the network like this:
|
||||
net.subnet().subnet().get_output();
|
||||
layer<2>(net).get_output();
|
||||
layer<avg_pool>(net).get_output();
|
||||
|
||||
net.clean();
|
||||
serialize("mnist_res_network.dat") << net;
|
||||
|
||||
|
||||
|
||||
typedef loss_multiclass_log<fc<avg_pool<
|
||||
ares<ares<ares<ares<
|
||||
repeat<10,ares,
|
||||
ares<
|
||||
typedef loss_multiclass_log<fc<number_of_classes,FC_HAS_BIAS,
|
||||
avg_pool<11,11,11,11,
|
||||
ares<ares<ares<ares_down<
|
||||
repeat<9,res,
|
||||
ares_down<
|
||||
ares<
|
||||
input<matrix<unsigned char>
|
||||
>>>>>>>>>>> test_net_type;
|
||||
|
|
Loading…
Reference in New Issue