From 2f34414e494ef364cd52927764cbc959cf048bdf Mon Sep 17 00:00:00 2001 From: Davis King Date: Tue, 8 Dec 2015 21:32:48 -0500 Subject: [PATCH] Added affine_ layer. --- dlib/dnn/layers.h | 81 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h index 2033c0d9c..39b6aa0d3 100644 --- a/dlib/dnn/layers.h +++ b/dlib/dnn/layers.h @@ -423,6 +423,87 @@ namespace dlib template using dropout = add_layer; +// ---------------------------------------------------------------------------------------- + + class affine_ + { + public: + affine_( + ) + { + } + + template + void setup (const SUBNET& sub) + { + gamma = alias_tensor(1, + sub.get_output().k(), + sub.get_output().nr(), + sub.get_output().nc()); + beta = gamma; + + params.set_size(gamma.size()+beta.size()); + + gamma(params,0) = 1; + beta(params,gamma.size()) = 0; + } + + void forward_inplace(const tensor& input, tensor& output) + { + auto g = gamma(params,0); + auto b = beta(params,gamma.size()); + tt::affine_transform(output, input, g, b); + } + + void backward_inplace( + const tensor& computed_output, + const tensor& gradient_input, + tensor& data_grad, + tensor& params_grad + ) + { + auto g = gamma(params,0); + auto b = beta(params,gamma.size()); + auto g_grad = gamma(params_grad,0); + auto b_grad = beta(params_grad,gamma.size()); + + // We are computing the gradient of dot(gradient_input, computed_output*g + b) + tt::multiply(data_grad, gradient_input, g); + + tt::multiply(g_grad, gradient_input, computed_output); + tt::add_bias_gradient(b_grad, gradient_input); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const affine_& item, std::ostream& out) + { + serialize("affine_", out); + serialize(item.params, out); + serialize(item.gamma, out); + serialize(item.beta, out); + } + + friend void deserialize(affine_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "affine_") + throw serialization_error("Unexpected version found while deserializing dlib::affine_."); + deserialize(item.params, in); + deserialize(item.gamma, in); + deserialize(item.beta, in); + } + + private: + resizable_tensor params; + alias_tensor gamma, beta; + }; + + template + using affine = add_layer; + // ---------------------------------------------------------------------------------------- class relu_