Add tril_ layer for lower triangular matrix operations (#3018)

* Add tril_ layer for lower triangular matrix operations

* Improved layer consistency

* Added constant_wrapper to fix the issue of the float in the template in c++17

* Looking for a solution for c++ 14

* Refactor tril_ layer for improved flexibility and C++14 compatibility

* Updates

* Updates

* Updates

* Updates

* Updates

* Updates
This commit is contained in:
Cydral 2024-09-30 03:59:21 +02:00 committed by GitHub
parent 72822fe4f9
commit 4e53f83160
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 346 additions and 0 deletions

View File

@ -4696,6 +4696,132 @@ namespace dlib
template <typename SUBNET> using transpose = add_layer<transpose_, SUBNET>;
// ----------------------------------------------------------------------------------------
struct neg_infinity_tag {};
struct zero_tag {};
template<typename T>
struct is_special_value : std::false_type {};
template<>
struct is_special_value<neg_infinity_tag> : std::true_type {};
template<>
struct is_special_value<zero_tag> : std::true_type {};
template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
class tril_
{
public:
tril_(): diag(diag_), diag_value(compute_diag_value()) {}
template <typename SUBNET>
void setup(const SUBNET& /*sub*/)
{
}
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
auto& prev = sub.get_output();
output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc());
check_mask(prev);
tt::multiply(false, output, prev, binary_mask);
if (diag_value != 0.0f) tt::add(1, output, 1, output_mask);
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
{
auto& prev_grad = sub.get_gradient_input();
tt::multiply(true, prev_grad, gradient_input, binary_mask);
}
inline dpoint map_input_to_output(const dpoint& p) const { return p; }
inline dpoint map_output_to_input(const dpoint& p) const { return p; }
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const tril_& item, std::ostream& out)
{
serialize("tril_", out);
serialize(item.diag, out);
serialize(item.diag_value, out);
}
friend void deserialize(tril_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "tril_")
throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_.");
deserialize(item.diag, in);
deserialize(item.diag_value, in);
}
friend std::ostream& operator<<(std::ostream& out, const tril_& item)
{
out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")";
return out;
}
friend void to_xml(const tril_& item, std::ostream& out)
{
out << "<tril diag='" << item.diag << "' diag_value='" << item.diag_value << "'/>\n";
}
private:
float compute_diag_value() const {
if (std::is_same<tag_, neg_infinity_tag>::value)
return -std::numeric_limits<float>::infinity();
else if (std::is_same<tag_, zero_tag>::value)
return 0.0f;
else
return static_cast<float>(num_) / static_cast<float>(den_);
}
void check_mask(const tensor& t)
{
if (!have_same_dimensions(binary_mask, t)) {
binary_mask.copy_size(t);
binary_mask = 1;
if (diag_value != 0.0f) {
output_mask.copy_size(t);
output_mask = 0;
}
for (long s = 0; s < output_mask.num_samples(); ++s)
{
for (long k = 0; k < output_mask.k(); ++k)
{
for (long r = 0; r < output_mask.nr(); ++r)
{
for (long c = std::max(r + diag + 1, 0L); c < output_mask.nc(); ++c)
{
if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value;
binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0;
}
}
}
}
}
}
template <typename T>
struct always_false : std::false_type {};
resizable_tensor params; // unused
resizable_tensor binary_mask, output_mask;
long diag;
float diag_value;
};
template <typename SUBNET>
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;
template <typename SUBNET>
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;
template <long diag, long num, long den, typename SUBNET>
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
// ----------------------------------------------------------------------------------------
}

View File

@ -3711,6 +3711,162 @@ namespace dlib
template <typename SUBNET>
using transpose = add_layer<transpose_, SUBNET>;
// ----------------------------------------------------------------------------------------
struct neg_infinity_tag {};
struct zero_tag {};
template<typename T>
struct is_special_value : std::false_type {};
template<>
struct is_special_value<neg_infinity_tag> : std::true_type {};
template<>
struct is_special_value<zero_tag> : std::true_type {};
template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
class tril_
{
/*!
TEMPLATE PARAMETERS
- diag_: A long integer specifying the diagonal offset.
- tag_: A type tag specifying special values or void for numeric values.
- num_: Numerator for numeric diagonal value (default is 0, only used if tag_ is void).
- den_: Denominator for numeric diagonal value (default is 1, only used if tag_ is void).
REQUIREMENTS
- diag_ must be an integer.
- tag_ must be either neg_infinity_tag, zero_tag, or void.
- If tag_ is void, num_ and den_ are used to compute the diagonal value.
- If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored.
WHAT THIS OBJECT REPRESENTS
This object implements a layer in a deep neural network that applies a lower triangular mask to
its input tensor. The mask is defined such that all elements above the specified diagonal are set
to a given value. The diagonal offset and the mask value are determined by the template parameters.
DIAGONAL VALUE DETERMINATION
- If tag_ is neg_infinity_tag: diagonal value is set to negative infinity.
- If tag_ is zero_tag: diagonal value is set to zero.
- If tag_ is void: diagonal value is set to num_ / den_ as a float.
DIAGONAL OFFSET
The diag_ parameter determines the diagonal above which elements are masked:
- diag_ = 0: main diagonal
- diag_ > 0: diag_ steps above the main diagonal
- diag_ < 0: |diag_| steps below the main diagonal
EXAMPLE USAGE
// Create a layer that masks all elements above the main diagonal with -inf
tril_<0, neg_infinity_tag> layer1;
// Create a layer that masks all elements above the main diagonal with 0
tril_<0, zero_tag> layer2;
// Create a layer that masks all elements above the main diagonal with 0.5
tril_<0, void, 1, 2> layer3;
// Create a layer that masks all elements 5 positions above the main diagonal with -inf
tril_<5, neg_infinity_tag> layer4;
// Create a layer that masks all elements 3 positions below the main diagonal with 0.25
tril_<-3, void, 1, 4> layer5;
SERIALIZATION SUPPORT
This object supports serialization and deserialization via the serialize() and deserialize() functions.
!*/
public:
tril_() = default;
/*!
ensures
- This object is properly initialized.
!*/
template <typename SUBNET>
void setup(const SUBNET& sub);
/*!
requires
- SUBNET is a valid network layer type.
ensures
- Initializes the mask based on the dimensions of the input tensor from sub.
!*/
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output);
/*!
requires
- SUBNET is a valid network layer type.
ensures
- Applies the lower triangular mask to the input tensor from sub and stores the result in output.
!*/
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
/*!
requires
- SUBNET is a valid network layer type.
ensures
- Computes the gradient of the loss with respect to the input tensor and stores it in sub.
!*/
inline dpoint map_input_to_output(const dpoint& p) const;
/*!
ensures
- Maps a point from the input tensor to the corresponding point in the output tensor.
!*/
inline dpoint map_output_to_input(const dpoint& p) const;
/*!
ensures
- Maps a point from the output tensor to the corresponding point in the input tensor.
!*/
const tensor& get_layer_params() const;
/*!
ensures
- Returns the parameters of this layer.
!*/
tensor& get_layer_params();
/*!
ensures
- Returns the parameters of this layer.
!*/
friend void serialize(const tril_& item, std::ostream& out);
/*!
ensures
- Serializes the state of this object to the given output stream.
!*/
friend void deserialize(tril_& item, std::istream& in);
/*!
ensures
- Deserializes the state of this object from the given input stream.
!*/
friend std::ostream& operator<<(std::ostream& out, const tril_& item);
/*!
ensures
- Prints a human-readable representation of this object to the given output stream.
!*/
friend void to_xml(const tril_& item, std::ostream& out);
/*!
ensures
- Serializes the state of this object to XML format and writes it to the given output stream.
!*/
};
template <typename SUBNET>
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;
template <typename SUBNET>
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;
template <long diag, long num, long den, typename SUBNET>
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
// ----------------------------------------------------------------------------------------
}

View File

@ -1029,6 +1029,22 @@ namespace dlib
update(i);
}
template <long diag, typename tag, long num, long den, typename U, typename E>
void operator()(size_t i, const add_layer<tril_<diag, tag, num, den>, U, E>&)
{
start_node(i, "tril");
out << " | {diag|{" << diag << "}}";
out << " | {diag_value|{";
if (std::is_same<tag, neg_infinity_tag>::value) out << "-inf";
else if (std::is_same<tag, zero_tag>::value) out << "0";
else out << static_cast<float>(num) / static_cast<float>(den);
out << "}}";
end_node();
update(i);
}
template <typename T, typename U, typename E>
void operator()(size_t i, const add_layer<T, U, E>&)
{

View File

@ -2023,6 +2023,12 @@ namespace
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
tril_<-5, void, 1, 2> l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
extract_<0,2,2,2> l;
@ -4447,6 +4453,47 @@ namespace
}
}
// ----------------------------------------------------------------------------------------
void test_tril()
{
print_spinner();
using net_type = tag1<tril_mask<tag2<input<matrix<float>>>>>;
net_type net;
// Input tensor
dlib::rand rnd;
const int nr = 2, nc = 3;
constexpr int n_samples = 3, k = 1;
std::vector<matrix<float>> x(n_samples);
matrix<float> xtmp(nr, nc);
for (int ii = 0; ii < n_samples; ++ii) {
for (int jj = 0; jj < nr; ++jj)
for (int kk = 0; kk < nc; ++kk)
xtmp(jj, kk) = rnd.get_random_gaussian();
x[ii] = xtmp;
}
// Convert input matrix to tensor
resizable_tensor input_tensor;
net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
net.forward(input_tensor);
// Expected output tensor (manually set for comparison)
resizable_tensor expected_output;
expected_output.copy_size(input_tensor);
tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
for (int ii = 0; ii < n_samples; ++ii) {
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
}
// Compare output tensor with expected output
auto& net_output = layer<tag1>(net).get_output();
DLIB_TEST(max(abs(mat(net_output) - mat(expected_output))) < 1e-5);
}
// ----------------------------------------------------------------------------------------
class dnn_tester : public tester
@ -4527,6 +4574,7 @@ namespace
test_layer_normalize();
test_rms_normalize();
test_transpose();
test_tril();
test_basic_tensor_ops();
test_layers();
test_visit_functions();