Added affine_ layer.

This commit is contained in:
Davis King 2015-12-08 21:32:48 -05:00
parent 8062663c78
commit 2f34414e49
1 changed files with 81 additions and 0 deletions

View File

@ -423,6 +423,87 @@ namespace dlib
template <typename SUBNET>
using dropout = add_layer<dropout_, SUBNET>;
// ----------------------------------------------------------------------------------------
class affine_
{
public:
affine_(
)
{
}
template <typename SUBNET>
void setup (const SUBNET& sub)
{
gamma = alias_tensor(1,
sub.get_output().k(),
sub.get_output().nr(),
sub.get_output().nc());
beta = gamma;
params.set_size(gamma.size()+beta.size());
gamma(params,0) = 1;
beta(params,gamma.size()) = 0;
}
void forward_inplace(const tensor& input, tensor& output)
{
auto g = gamma(params,0);
auto b = beta(params,gamma.size());
tt::affine_transform(output, input, g, b);
}
void backward_inplace(
const tensor& computed_output,
const tensor& gradient_input,
tensor& data_grad,
tensor& params_grad
)
{
auto g = gamma(params,0);
auto b = beta(params,gamma.size());
auto g_grad = gamma(params_grad,0);
auto b_grad = beta(params_grad,gamma.size());
// We are computing the gradient of dot(gradient_input, computed_output*g + b)
tt::multiply(data_grad, gradient_input, g);
tt::multiply(g_grad, gradient_input, computed_output);
tt::add_bias_gradient(b_grad, gradient_input);
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const affine_& item, std::ostream& out)
{
serialize("affine_", out);
serialize(item.params, out);
serialize(item.gamma, out);
serialize(item.beta, out);
}
friend void deserialize(affine_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "affine_")
throw serialization_error("Unexpected version found while deserializing dlib::affine_.");
deserialize(item.params, in);
deserialize(item.gamma, in);
deserialize(item.beta, in);
}
private:
resizable_tensor params;
alias_tensor gamma, beta;
};
template <typename SUBNET>
using affine = add_layer<affine_, SUBNET>;
// ----------------------------------------------------------------------------------------
class relu_