From 170877da88fdd887b9404aec78646e604d8368d8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= <1671644+arrufat@users.noreply.github.com>
Date: Wed, 28 Aug 2019 20:25:08 +0900
Subject: [PATCH] add loss_mean_squared_per_channel (#1863)

add loss_mean_squared_per_channel_and_pixel
---
 dlib/dnn/loss.h          | 144 ++++++++++++++++++++++++++++++++++++++-
 dlib/dnn/loss_abstract.h |  63 ++++++++++++++++-
 dlib/test/dnn.cpp        |  61 +++++++++++++++++
 3 files changed, 266 insertions(+), 2 deletions(-)

diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h
index 5b5c53ce8..69fd8215e 100644
--- a/dlib/dnn/loss.h
+++ b/dlib/dnn/loss.h
@@ -2891,7 +2891,149 @@ namespace dlib
 
 // ----------------------------------------------------------------------------------------
 
-    class loss_dot_
+    template <long _num_channels>
+    class loss_mean_squared_per_channel_and_pixel_
+    {
+    public:
+
+        typedef std::array<matrix<float>, _num_channels> training_label_type;
+        typedef std::array<matrix<float>, _num_channels> output_label_type;
+
+        template <
+            typename SUB_TYPE,
+            typename label_iterator
+            >
+        void to_label (
+            const tensor& input_tensor,
+            const SUB_TYPE& sub,
+            label_iterator iter
+        ) const
+        {
+            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+            const tensor& output_tensor = sub.get_output();
+
+            DLIB_CASSERT(output_tensor.k() == _num_channels, "output k = " << output_tensor.k());
+            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+            const float* out_data = output_tensor.host();
+
+            for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter)
+            {
+                for (long k = 0; k < output_tensor.k(); ++k)
+                {
+                    (*iter)[k].set_size(output_tensor.nr(), output_tensor.nc());
+                    for (long r = 0; r < output_tensor.nr(); ++r)
+                    {
+                        for (long c = 0; c < output_tensor.nc(); ++c)
+                        {
+                            (*iter)[k].operator()(r, c) = out_data[tensor_index(output_tensor, i, k, r, c)];
+                        }
+                    }
+                }
+            }
+        }
+
+
+        template <
+            typename const_label_iterator,
+            typename SUBNET
+            >
+        double compute_loss_value_and_gradient (
+            const tensor& input_tensor,
+            const_label_iterator truth,
+            SUBNET& sub
+        ) const
+        {
+            const tensor& output_tensor = sub.get_output();
+            tensor& grad = sub.get_gradient_input();
+
+            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+            DLIB_CASSERT(input_tensor.num_samples() != 0);
+            DLIB_CASSERT(input_tensor.num_samples() % sub.sample_expansion_factor() == 0);
+            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+            DLIB_CASSERT(output_tensor.k() == _num_channels);
+            DLIB_CASSERT(output_tensor.nr() == grad.nr() &&
+                         output_tensor.nc() == grad.nc() &&
+                         output_tensor.k() == grad.k());
+            for (long idx = 0; idx < output_tensor.num_samples(); ++idx)
+            {
+                const_label_iterator truth_matrix_ptr = (truth + idx);
+                DLIB_CASSERT((*truth_matrix_ptr).size() == _num_channels);
+                for (long k = 0; k < output_tensor.k(); ++k)
+                {
+                    DLIB_CASSERT((*truth_matrix_ptr)[k].nr() == output_tensor.nr() &&
+                                 (*truth_matrix_ptr)[k].nc() == output_tensor.nc(),
+                                 "truth size = " << (*truth_matrix_ptr)[k].nr() << " x " << (*truth_matrix_ptr)[k].nc() << ", "
+                                 "output size = " << output_tensor.nr() << " x " << output_tensor.nc());
+                }
+            }
+
+            // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
+            const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.k() * output_tensor.nr() * output_tensor.nc());
+            double loss = 0;
+            float* const g = grad.host();
+            const float* out_data = output_tensor.host();
+            for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth)
+            {
+                for (long k = 0; k < output_tensor.k(); ++k)
+                {
+                    for (long r = 0; r < output_tensor.nr(); ++r)
+                    {
+                        for (long c = 0; c < output_tensor.nc(); ++c)
+                        {
+                            const float y = (*truth)[k].operator()(r, c);
+                            const size_t idx = tensor_index(output_tensor, i, k, r, c);
+                            const float temp1 = y - out_data[idx];
+                            const float temp2 = scale*temp1;
+                            loss += temp2*temp1;
+                            g[idx] = -temp2;
+                        }
+                    }
+                }
+            }
+            return loss;
+        }
+
+        friend void serialize(const loss_mean_squared_per_channel_and_pixel_& , std::ostream& out)
+        {
+            serialize("loss_mean_squared_per_channel_and_pixel_", out);
+        }
+
+        friend void deserialize(loss_mean_squared_per_channel_and_pixel_& , std::istream& in)
+        {
+            std::string version;
+            deserialize(version, in);
+            if (version != "loss_mean_squared_per_channel_and_pixel_")
+                throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_per_channel_and_pixel_.");
+        }
+
+        friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_per_channel_and_pixel_& )
+        {
+            out << "loss_mean_squared_per_channel_and_pixel";
+            return out;
+        }
+
+        friend void to_xml(const loss_mean_squared_per_channel_and_pixel_& /*item*/, std::ostream& out)
+        {
+            out << "<loss_mean_squared_per_channel_and_pixel/>";
+        }
+
+    private:
+        static size_t tensor_index(const tensor& t, long sample, long k, long row, long column)
+        {
+            // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
+            return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
+        }
+    };
+
+    template <long num_channels, typename SUBNET>
+    using loss_mean_squared_per_channel_and_pixel = add_loss_layer<loss_mean_squared_per_channel_and_pixel_<num_channels>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+    class loss_dot_
     {
     public:
 
diff --git a/dlib/dnn/loss_abstract.h b/dlib/dnn/loss_abstract.h
index 5a6b0c558..dd50fdba7 100644
--- a/dlib/dnn/loss_abstract.h
+++ b/dlib/dnn/loss_abstract.h
@@ -1495,7 +1495,68 @@ namespace dlib
 
 // ----------------------------------------------------------------------------------------
 
-    class loss_dot_
+    template <long _num_channels>
+    class loss_mean_squared_per_channel_and_pixel_
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This object implements the loss layer interface defined above by
+                EXAMPLE_LOSS_LAYER_.  In particular, it implements the mean squared loss,
+                which is appropriate for regression problems.  It is basically just like
+                loss_mean_squared_per_pixel_ except that it computes the loss over all
+                channels, not just the first one.
+        !*/
+    public:
+
+        typedef std::array<matrix<float>, _num_channels> training_label_type;
+        typedef std::array<matrix<float>, _num_channels> output_label_type;
+
+
+        template <
+            typename SUB_TYPE,
+            typename label_iterator
+            >
+        void to_label (
+            const tensor& input_tensor,
+            const SUB_TYPE& sub,
+            label_iterator iter
+        ) const;
+        /*!
+            This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+            it has the additional calling requirements that:
+                - sub.get_output().num_samples() == input_tensor.num_samples()
+                - sub.get_output().k() == _num_channels
+                - sub.sample_expansion_factor() == 1
+            and the output labels are the predicted continuous variables.
+        !*/
+
+        template <
+            typename const_label_iterator,
+            typename SUBNET
+            >
+        double compute_loss_value_and_gradient (
+            const tensor& input_tensor,
+            const_label_iterator truth,
+            SUBNET& sub
+        ) const;
+        /*!
+            This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+            except it has the additional calling requirements that:
+                - sub.get_output().k() == _num_channels
+                - sub.get_output().num_samples() == input_tensor.num_samples()
+                - sub.sample_expansion_factor() == 1
+                - for all idx such that 0 <= idx < sub.get_output().num_samples():
+                    - sub.get_output().nr() == (*(truth + idx)).nr()
+                    - sub.get_output().nc() == (*(truth + idx)).nc()
+        !*/
+    };
+
+    template <long num_channels, typename SUBNET>
+    using loss_mean_squared_per_channel_and_pixel = add_loss_layer<loss_mean_squared_per_channel_and_pixel_<num_channels>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+    class loss_dot_
     {
         /*!
             WHAT THIS OBJECT REPRESENTS
diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index 4d4d98347..bdbc15119 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -2495,6 +2495,66 @@ namespace
         DLIB_TEST_MSG(error_after < 1e-6, "Autoencoder error after training = " << error_after);
     }
 
+// ----------------------------------------------------------------------------------------
+
+    void test_loss_mean_squared_per_channel_and_pixel()
+    {
+        print_spinner();
+
+        const int num_samples = 1000;
+        const long num_channels = 2;
+        const long dimension = 3;
+        ::std::vector<matrix<float>> inputs;
+        ::std::vector<::std::array<matrix<float>, num_channels>> labels;
+        for (int i = 0; i < num_samples; ++i)
+        {
+            matrix<float> x = matrix_cast<float>(randm(5, dimension));
+            matrix<float> w = matrix_cast<float>(randm(num_channels, 5));
+            matrix<float> y = w * x;
+            DLIB_CASSERT(y.nr() == num_channels);
+            ::std::array<matrix<float>, num_channels> y_arr;
+            // convert y to an array of matrices
+            for (long c = 0; c < num_channels; ++c)
+            {
+                y_arr[c] = rowm(y, c);
+            }
+            inputs.push_back(::std::move(x));
+            labels.push_back(::std::move(y_arr));
+        }
+
+        const long num_outputs = num_channels * dimension;
+        using net_type = loss_mean_squared_per_channel_and_pixel<num_channels,
+                            extract<0, num_channels, 1, dimension,
+                            fc<num_outputs,
+                            relu<bn_fc<fc<500,
+                            input<matrix<float>>>>>>>>;
+        net_type net;
+
+        const auto compute_error = [&inputs, &labels, &net, num_channels]()
+        {
+            const auto out = net(inputs);
+            double error = 0.0;
+            for (size_t i = 0; i < out.size(); ++i)
+            {
+                for (size_t c = 0; c < num_channels; ++c)
+                {
+                    error += mean(squared(out[i][c] - labels[i][c]));
+                }
+            }
+            return error / out.size() / num_channels;
+        };
+
+        const auto error_before = compute_error();
+        dnn_trainer<net_type> trainer(net);
+        trainer.set_learning_rate(0.1);
+        trainer.set_iterations_without_progress_threshold(500);
+        trainer.set_min_learning_rate(1e-6);
+        trainer.set_mini_batch_size(50);
+        trainer.train(inputs, labels);
+        const auto error_after = compute_error();
+        DLIB_TEST_MSG(error_after < error_before, "multi channel error increased after training");
+    }
+
 // ----------------------------------------------------------------------------------------
 
     void test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task()
@@ -3252,6 +3312,7 @@ namespace
         test_simple_linear_regression_with_mult_prev();
         test_multioutput_linear_regression();
         test_simple_autoencoder();
+        test_loss_mean_squared_per_channel_and_pixel();
        test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task();
        test_loss_multiclass_per_pixel_activations_on_trivial_single_pixel_task();
        test_loss_multiclass_per_pixel_outputs_on_trivial_task();
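
For context, a minimal sketch of how the new loss might be used outside the test suite: a tiny convolutional network that regresses a 2-channel image from a single-channel input. The architecture, image sizes, and the main() harness below are illustrative assumptions and not part of this patch; only loss_mean_squared_per_channel_and_pixel and its std::array<matrix<float>, N> label type come from the changes above.

    // Sketch only: a small per-pixel, 2-channel regression net built on the new
    // loss layer.  The layer layout and sizes are illustrative assumptions.
    #include <array>
    #include <iostream>
    #include <vector>
    #include <dlib/dnn.h>

    using namespace dlib;

    const long nch = 2;  // number of regressed channels

    // A 3x3 stride-1 conv (dlib pads it so the spatial size is preserved),
    // then a 1x1 conv producing exactly nch output channels, as the loss requires.
    using sketch_net = loss_mean_squared_per_channel_and_pixel<nch,
                       con<nch, 1, 1, 1, 1,
                       relu<con<16, 3, 3, 1, 1,
                       input<matrix<float>>>>>>;

    int main()
    {
        sketch_net net;

        // One toy sample: a 32x32 single-channel input and a 2-channel target of
        // the same spatial size (random values, just to exercise the interface).
        std::vector<matrix<float>> samples;
        samples.push_back(matrix_cast<float>(randm(32, 32)));

        std::array<matrix<float>, nch> target;
        for (auto& channel : target)
            channel = matrix_cast<float>(randm(32, 32));
        std::vector<std::array<matrix<float>, nch>> labels(1, target);

        dnn_trainer<sketch_net> trainer(net);
        trainer.set_learning_rate(0.01);
        trainer.set_mini_batch_size(1);
        trainer.train_one_step(samples, labels);
        trainer.get_net();  // wait for the step to finish and sync the weights back into net

        // to_label() returns one std::array<matrix<float>, nch> per input sample.
        const auto out = net(samples);
        std::cout << "channels: " << out[0].size()
                  << ", size: " << out[0][0].nr() << "x" << out[0][0].nc() << std::endl;
        return 0;
    }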