mirror of https://github.com/davisking/dlib.git
Added affine_ layer.
This commit is contained in:
parent
8062663c78
commit
2f34414e49
|
@ -423,6 +423,87 @@ namespace dlib
|
|||
template <typename SUBNET>
|
||||
using dropout = add_layer<dropout_, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class affine_
|
||||
{
|
||||
public:
|
||||
affine_(
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void setup (const SUBNET& sub)
|
||||
{
|
||||
gamma = alias_tensor(1,
|
||||
sub.get_output().k(),
|
||||
sub.get_output().nr(),
|
||||
sub.get_output().nc());
|
||||
beta = gamma;
|
||||
|
||||
params.set_size(gamma.size()+beta.size());
|
||||
|
||||
gamma(params,0) = 1;
|
||||
beta(params,gamma.size()) = 0;
|
||||
}
|
||||
|
||||
void forward_inplace(const tensor& input, tensor& output)
|
||||
{
|
||||
auto g = gamma(params,0);
|
||||
auto b = beta(params,gamma.size());
|
||||
tt::affine_transform(output, input, g, b);
|
||||
}
|
||||
|
||||
void backward_inplace(
|
||||
const tensor& computed_output,
|
||||
const tensor& gradient_input,
|
||||
tensor& data_grad,
|
||||
tensor& params_grad
|
||||
)
|
||||
{
|
||||
auto g = gamma(params,0);
|
||||
auto b = beta(params,gamma.size());
|
||||
auto g_grad = gamma(params_grad,0);
|
||||
auto b_grad = beta(params_grad,gamma.size());
|
||||
|
||||
// We are computing the gradient of dot(gradient_input, computed_output*g + b)
|
||||
tt::multiply(data_grad, gradient_input, g);
|
||||
|
||||
tt::multiply(g_grad, gradient_input, computed_output);
|
||||
tt::add_bias_gradient(b_grad, gradient_input);
|
||||
}
|
||||
|
||||
const tensor& get_layer_params() const { return params; }
|
||||
tensor& get_layer_params() { return params; }
|
||||
|
||||
friend void serialize(const affine_& item, std::ostream& out)
|
||||
{
|
||||
serialize("affine_", out);
|
||||
serialize(item.params, out);
|
||||
serialize(item.gamma, out);
|
||||
serialize(item.beta, out);
|
||||
}
|
||||
|
||||
friend void deserialize(affine_& item, std::istream& in)
|
||||
{
|
||||
std::string version;
|
||||
deserialize(version, in);
|
||||
if (version != "affine_")
|
||||
throw serialization_error("Unexpected version found while deserializing dlib::affine_.");
|
||||
deserialize(item.params, in);
|
||||
deserialize(item.gamma, in);
|
||||
deserialize(item.beta, in);
|
||||
}
|
||||
|
||||
private:
|
||||
resizable_tensor params;
|
||||
alias_tensor gamma, beta;
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using affine = add_layer<affine_, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class relu_
|
||||
|
|
Loading…
Reference in New Issue