mirror of https://github.com/davisking/dlib.git
Add tril_ layer for lower triangular matrix operations (#3018)
* Add tril_ layer for lower triangular matrix operations * Improved layer consistency * Added constant_wrapper to fix the issue of the float in the template in c++17 * Looking for a solution for c++ 14 * Refactor tril_ layer for improved flexibility and C++14 compatibility * Updates * Updates * Updates * Updates * Updates * Updates
This commit is contained in:
parent
72822fe4f9
commit
4e53f83160
|
@ -4696,6 +4696,132 @@ namespace dlib
|
|||
|
||||
// Convenience alias: add_layer instantiated with transpose_ as the layer
// details type and SUBNET as the sub-network.
template <typename SUBNET> using transpose = add_layer<transpose_, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// Tag types used as the tag_ template argument of tril_ to pick a mask value
// that cannot be written as a ratio of two longs.
struct neg_infinity_tag {};   // mask value is -infinity
struct zero_tag {};           // mask value is 0

// Type trait: is_special_value<T>::value is true exactly when T is one of the
// tag types above, and false for every other type (including void).
template <typename T>
struct is_special_value : std::false_type
{
};

template <>
struct is_special_value<neg_infinity_tag> : std::true_type
{
};

template <>
struct is_special_value<zero_tag> : std::true_type
{
};
|
||||
|
||||
template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
|
||||
class tril_
|
||||
{
|
||||
public:
|
||||
tril_(): diag(diag_), diag_value(compute_diag_value()) {}
|
||||
|
||||
template <typename SUBNET>
|
||||
void setup(const SUBNET& /*sub*/)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void forward(const SUBNET& sub, resizable_tensor& output)
|
||||
{
|
||||
auto& prev = sub.get_output();
|
||||
output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc());
|
||||
|
||||
check_mask(prev);
|
||||
tt::multiply(false, output, prev, binary_mask);
|
||||
if (diag_value != 0.0f) tt::add(1, output, 1, output_mask);
|
||||
}
|
||||
template <typename SUBNET>
|
||||
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
|
||||
{
|
||||
auto& prev_grad = sub.get_gradient_input();
|
||||
tt::multiply(true, prev_grad, gradient_input, binary_mask);
|
||||
}
|
||||
|
||||
inline dpoint map_input_to_output(const dpoint& p) const { return p; }
|
||||
inline dpoint map_output_to_input(const dpoint& p) const { return p; }
|
||||
|
||||
const tensor& get_layer_params() const { return params; }
|
||||
tensor& get_layer_params() { return params; }
|
||||
|
||||
friend void serialize(const tril_& item, std::ostream& out)
|
||||
{
|
||||
serialize("tril_", out);
|
||||
serialize(item.diag, out);
|
||||
serialize(item.diag_value, out);
|
||||
}
|
||||
friend void deserialize(tril_& item, std::istream& in)
|
||||
{
|
||||
std::string version;
|
||||
deserialize(version, in);
|
||||
if (version != "tril_")
|
||||
throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_.");
|
||||
deserialize(item.diag, in);
|
||||
deserialize(item.diag_value, in);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const tril_& item)
|
||||
{
|
||||
out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")";
|
||||
return out;
|
||||
}
|
||||
friend void to_xml(const tril_& item, std::ostream& out)
|
||||
{
|
||||
out << "<tril diag='" << item.diag << "' diag_value='" << item.diag_value << "'/>\n";
|
||||
}
|
||||
|
||||
private:
|
||||
float compute_diag_value() const {
|
||||
if (std::is_same<tag_, neg_infinity_tag>::value)
|
||||
return -std::numeric_limits<float>::infinity();
|
||||
else if (std::is_same<tag_, zero_tag>::value)
|
||||
return 0.0f;
|
||||
else
|
||||
return static_cast<float>(num_) / static_cast<float>(den_);
|
||||
}
|
||||
|
||||
void check_mask(const tensor& t)
|
||||
{
|
||||
if (!have_same_dimensions(binary_mask, t)) {
|
||||
binary_mask.copy_size(t);
|
||||
binary_mask = 1;
|
||||
if (diag_value != 0.0f) {
|
||||
output_mask.copy_size(t);
|
||||
output_mask = 0;
|
||||
}
|
||||
for (long s = 0; s < output_mask.num_samples(); ++s)
|
||||
{
|
||||
for (long k = 0; k < output_mask.k(); ++k)
|
||||
{
|
||||
for (long r = 0; r < output_mask.nr(); ++r)
|
||||
{
|
||||
for (long c = std::max(r + diag + 1, 0L); c < output_mask.nc(); ++c)
|
||||
{
|
||||
if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value;
|
||||
binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct always_false : std::false_type {};
|
||||
|
||||
resizable_tensor params; // unused
|
||||
resizable_tensor binary_mask, output_mask;
|
||||
long diag;
|
||||
float diag_value;
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;
|
||||
|
||||
template <typename SUBNET>
|
||||
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;
|
||||
|
||||
template <long diag, long num, long den, typename SUBNET>
|
||||
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
|
|
@ -3711,6 +3711,162 @@ namespace dlib
|
|||
template <typename SUBNET>
|
||||
using transpose = add_layer<transpose_, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// Tag types accepted as the tag_ template argument of tril_.
struct neg_infinity_tag {};   // selects -infinity as the mask value
struct zero_tag {};           // selects 0 as the mask value

// is_special_value<T>::value is true only for the two tag types above.
template <typename T>
struct is_special_value : std::false_type
{
};

template <>
struct is_special_value<neg_infinity_tag> : std::true_type
{
};

template <>
struct is_special_value<zero_tag> : std::true_type
{
};
|
||||
|
||||
template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
|
||||
class tril_
|
||||
{
|
||||
/*!
|
||||
TEMPLATE PARAMETERS
|
||||
- diag_: A long integer specifying the diagonal offset.
|
||||
- tag_: A type tag specifying special values or void for numeric values.
|
||||
- num_: Numerator for numeric diagonal value (default is 0, only used if tag_ is void).
|
||||
- den_: Denominator for numeric diagonal value (default is 1, only used if tag_ is void).
|
||||
|
||||
REQUIREMENTS
|
||||
- diag_ must be an integer.
|
||||
- tag_ must be either neg_infinity_tag, zero_tag, or void.
|
||||
- If tag_ is void, num_ and den_ are used to compute the diagonal value.
|
||||
- If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored.
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This object implements a layer in a deep neural network that applies a lower triangular mask to
|
||||
its input tensor. The mask is defined such that all elements above the specified diagonal are set
|
||||
to a given value. The diagonal offset and the mask value are determined by the template parameters.
|
||||
|
||||
DIAGONAL VALUE DETERMINATION
|
||||
- If tag_ is neg_infinity_tag: diagonal value is set to negative infinity.
|
||||
- If tag_ is zero_tag: diagonal value is set to zero.
|
||||
- If tag_ is void: diagonal value is set to num_ / den_ as a float.
|
||||
|
||||
DIAGONAL OFFSET
|
||||
The diag_ parameter determines the diagonal above which elements are masked:
|
||||
- diag_ = 0: main diagonal
|
||||
- diag_ > 0: diag_ steps above the main diagonal
|
||||
- diag_ < 0: |diag_| steps below the main diagonal
|
||||
|
||||
EXAMPLE USAGE
|
||||
// Create a layer that masks all elements above the main diagonal with -inf
|
||||
tril_<0, neg_infinity_tag> layer1;
|
||||
|
||||
// Create a layer that masks all elements above the main diagonal with 0
|
||||
tril_<0, zero_tag> layer2;
|
||||
|
||||
// Create a layer that masks all elements above the main diagonal with 0.5
|
||||
tril_<0, void, 1, 2> layer3;
|
||||
|
||||
// Create a layer that masks all elements 5 positions above the main diagonal with -inf
|
||||
tril_<5, neg_infinity_tag> layer4;
|
||||
|
||||
// Create a layer that masks all elements 3 positions below the main diagonal with 0.25
|
||||
tril_<-3, void, 1, 4> layer5;
|
||||
|
||||
SERIALIZATION SUPPORT
|
||||
This object supports serialization and deserialization via the serialize() and deserialize() functions.
|
||||
!*/
|
||||
|
||||
public:
|
||||
tril_() = default;
|
||||
/*!
|
||||
ensures
|
||||
- This object is properly initialized.
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
void setup(const SUBNET& sub);
|
||||
/*!
|
||||
requires
|
||||
- SUBNET is a valid network layer type.
|
||||
ensures
|
||||
- Initializes the mask based on the dimensions of the input tensor from sub.
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
void forward(const SUBNET& sub, resizable_tensor& output);
|
||||
/*!
|
||||
requires
|
||||
- SUBNET is a valid network layer type.
|
||||
ensures
|
||||
- Applies the lower triangular mask to the input tensor from sub and stores the result in output.
|
||||
!*/
|
||||
|
||||
template <typename SUBNET>
|
||||
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
|
||||
/*!
|
||||
requires
|
||||
- SUBNET is a valid network layer type.
|
||||
ensures
|
||||
- Computes the gradient of the loss with respect to the input tensor and stores it in sub.
|
||||
!*/
|
||||
|
||||
inline dpoint map_input_to_output(const dpoint& p) const;
|
||||
/*!
|
||||
ensures
|
||||
- Maps a point from the input tensor to the corresponding point in the output tensor.
|
||||
!*/
|
||||
|
||||
inline dpoint map_output_to_input(const dpoint& p) const;
|
||||
/*!
|
||||
ensures
|
||||
- Maps a point from the output tensor to the corresponding point in the input tensor.
|
||||
!*/
|
||||
|
||||
const tensor& get_layer_params() const;
|
||||
/*!
|
||||
ensures
|
||||
- Returns the parameters of this layer.
|
||||
!*/
|
||||
|
||||
tensor& get_layer_params();
|
||||
/*!
|
||||
ensures
|
||||
- Returns the parameters of this layer.
|
||||
!*/
|
||||
|
||||
friend void serialize(const tril_& item, std::ostream& out);
|
||||
/*!
|
||||
ensures
|
||||
- Serializes the state of this object to the given output stream.
|
||||
!*/
|
||||
|
||||
friend void deserialize(tril_& item, std::istream& in);
|
||||
/*!
|
||||
ensures
|
||||
- Deserializes the state of this object from the given input stream.
|
||||
!*/
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const tril_& item);
|
||||
/*!
|
||||
ensures
|
||||
- Prints a human-readable representation of this object to the given output stream.
|
||||
!*/
|
||||
|
||||
friend void to_xml(const tril_& item, std::ostream& out);
|
||||
/*!
|
||||
ensures
|
||||
- Serializes the state of this object to XML format and writes it to the given output stream.
|
||||
!*/
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;
|
||||
|
||||
template <typename SUBNET>
|
||||
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;
|
||||
|
||||
template <long diag, long num, long den, typename SUBNET>
|
||||
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
|
|
@ -1029,6 +1029,22 @@ namespace dlib
|
|||
update(i);
|
||||
}
|
||||
|
||||
template <long diag, typename tag, long num, long den, typename U, typename E>
|
||||
void operator()(size_t i, const add_layer<tril_<diag, tag, num, den>, U, E>&)
|
||||
{
|
||||
start_node(i, "tril");
|
||||
out << " | {diag|{" << diag << "}}";
|
||||
out << " | {diag_value|{";
|
||||
|
||||
if (std::is_same<tag, neg_infinity_tag>::value) out << "-inf";
|
||||
else if (std::is_same<tag, zero_tag>::value) out << "0";
|
||||
else out << static_cast<float>(num) / static_cast<float>(den);
|
||||
|
||||
out << "}}";
|
||||
end_node();
|
||||
update(i);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename E>
|
||||
void operator()(size_t i, const add_layer<T, U, E>&)
|
||||
{
|
||||
|
|
|
@ -2023,6 +2023,12 @@ namespace
|
|||
auto res = test_layer(l);
|
||||
DLIB_TEST_MSG(res, res);
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
tril_<-5, void, 1, 2> l;
|
||||
auto res = test_layer(l);
|
||||
DLIB_TEST_MSG(res, res);
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
extract_<0,2,2,2> l;
|
||||
|
@ -4447,6 +4453,47 @@ namespace
|
|||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void test_tril()
|
||||
{
|
||||
print_spinner();
|
||||
using net_type = tag1<tril_mask<tag2<input<matrix<float>>>>>;
|
||||
net_type net;
|
||||
|
||||
// Input tensor
|
||||
dlib::rand rnd;
|
||||
const int nr = 2, nc = 3;
|
||||
constexpr int n_samples = 3, k = 1;
|
||||
std::vector<matrix<float>> x(n_samples);
|
||||
matrix<float> xtmp(nr, nc);
|
||||
for (int ii = 0; ii < n_samples; ++ii) {
|
||||
for (int jj = 0; jj < nr; ++jj)
|
||||
for (int kk = 0; kk < nc; ++kk)
|
||||
xtmp(jj, kk) = rnd.get_random_gaussian();
|
||||
x[ii] = xtmp;
|
||||
}
|
||||
|
||||
// Convert input matrix to tensor
|
||||
resizable_tensor input_tensor;
|
||||
net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
|
||||
net.forward(input_tensor);
|
||||
|
||||
// Expected output tensor (manually set for comparison)
|
||||
resizable_tensor expected_output;
|
||||
expected_output.copy_size(input_tensor);
|
||||
tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
|
||||
for (int ii = 0; ii < n_samples; ++ii) {
|
||||
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
|
||||
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
|
||||
expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
|
||||
}
|
||||
|
||||
// Compare output tensor with expected output
|
||||
auto& net_output = layer<tag1>(net).get_output();
|
||||
DLIB_TEST(max(abs(mat(net_output) - mat(expected_output))) < 1e-5);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class dnn_tester : public tester
|
||||
|
@ -4527,6 +4574,7 @@ namespace
|
|||
test_layer_normalize();
|
||||
test_rms_normalize();
|
||||
test_transpose();
|
||||
test_tril();
|
||||
test_basic_tensor_ops();
|
||||
test_layers();
|
||||
test_visit_functions();
|
||||
|
|
Loading…
Reference in New Issue