From 57bb5eb58d1f6ef6424d7166e5036aee38c1ff7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= <1671644+arrufat@users.noreply.github.com> Date: Tue, 31 Mar 2020 09:20:50 +0900 Subject: [PATCH] use running stats to track losses (#2041) --- examples/dnn_dcgan_train_ex.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/dnn_dcgan_train_ex.cpp b/examples/dnn_dcgan_train_ex.cpp index 8eca83bdd..e9fb87d88 100644 --- a/examples/dnn_dcgan_train_ex.cpp +++ b/examples/dnn_dcgan_train_ex.cpp @@ -181,6 +181,7 @@ int main(int argc, char** argv) try const std::vector fake_labels(minibatch_size, -1); dlib::image_window win; resizable_tensor real_samples_tensor, fake_samples_tensor, noises_tensor; + running_stats g_loss, d_loss; while (iteration < 50000) { // Train the discriminator with real images @@ -192,7 +193,7 @@ int main(int argc, char** argv) try } // The following lines are equivalent to calling train_one_step(real_samples, real_labels) discriminator.to_tensor(real_samples.begin(), real_samples.end(), real_samples_tensor); - double d_loss = discriminator.compute_loss(real_samples_tensor, real_labels.begin()); + d_loss.add(discriminator.compute_loss(real_samples_tensor, real_labels.begin())); discriminator.back_propagate_error(real_samples_tensor); discriminator.update_parameters(d_solvers, learning_rate); @@ -210,7 +211,7 @@ int main(int argc, char** argv) try // 4. finally train the discriminator and wait for the threading to stop. The following // lines are equivalent to calling train_one_step(fake_samples, fake_labels) discriminator.to_tensor(fake_samples.begin(), fake_samples.end(), fake_samples_tensor); - d_loss += discriminator.compute_loss(fake_samples_tensor, fake_labels.begin()); + d_loss.add(discriminator.compute_loss(fake_samples_tensor, fake_labels.begin())); discriminator.back_propagate_error(fake_samples_tensor); discriminator.update_parameters(d_solvers, learning_rate); @@ -223,7 +224,7 @@ int main(int argc, char** argv) try // seen as test_one_step() plus the error back propagation. // Forward the fake samples and compute the loss with real labels - const auto g_loss = discriminator.compute_loss(fake_samples_tensor, real_labels.begin()); + g_loss.add(discriminator.compute_loss(fake_samples_tensor, real_labels.begin())); // Back propagate the error to fill the final data gradient discriminator.back_propagate_error(fake_samples_tensor); // Get the gradient that will tell the generator how to update itself @@ -238,10 +239,12 @@ int main(int argc, char** argv) try serialize("dcgan_sync") << generator << discriminator << iteration; std::cout << "step#: " << iteration << - "\tdiscriminator loss: " << d_loss << - "\tgenerator loss: " << g_loss << '\n'; + "\tdiscriminator loss: " << d_loss.mean() * 2 << + "\tgenerator loss: " << g_loss.mean() << '\n'; win.set_image(tile_images(fake_samples)); win.set_title("DCGAN step#: " + to_string(iteration)); + d_loss.clear(); + g_loss.clear(); } }