use running stats to track losses (#2041)

This commit is contained in:
Adrià Arrufat 2020-03-31 09:20:50 +09:00 committed by GitHub
parent 0057461a62
commit 57bb5eb58d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 5 deletions

View File

@ -181,6 +181,7 @@ int main(int argc, char** argv) try
const std::vector<float> fake_labels(minibatch_size, -1);
dlib::image_window win;
resizable_tensor real_samples_tensor, fake_samples_tensor, noises_tensor;
running_stats<double> g_loss, d_loss;
while (iteration < 50000)
{
// Train the discriminator with real images
@ -192,7 +193,7 @@ int main(int argc, char** argv) try
}
// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
discriminator.to_tensor(real_samples.begin(), real_samples.end(), real_samples_tensor);
double d_loss = discriminator.compute_loss(real_samples_tensor, real_labels.begin());
d_loss.add(discriminator.compute_loss(real_samples_tensor, real_labels.begin()));
discriminator.back_propagate_error(real_samples_tensor);
discriminator.update_parameters(d_solvers, learning_rate);
@ -210,7 +211,7 @@ int main(int argc, char** argv) try
// 4. finally train the discriminator and wait for the threading to stop. The following
// lines are equivalent to calling train_one_step(fake_samples, fake_labels)
discriminator.to_tensor(fake_samples.begin(), fake_samples.end(), fake_samples_tensor);
d_loss += discriminator.compute_loss(fake_samples_tensor, fake_labels.begin());
d_loss.add(discriminator.compute_loss(fake_samples_tensor, fake_labels.begin()));
discriminator.back_propagate_error(fake_samples_tensor);
discriminator.update_parameters(d_solvers, learning_rate);
@ -223,7 +224,7 @@ int main(int argc, char** argv) try
// seen as test_one_step() plus the error back propagation.
// Forward the fake samples and compute the loss with real labels
const auto g_loss = discriminator.compute_loss(fake_samples_tensor, real_labels.begin());
g_loss.add(discriminator.compute_loss(fake_samples_tensor, real_labels.begin()));
// Back propagate the error to fill the final data gradient
discriminator.back_propagate_error(fake_samples_tensor);
// Get the gradient that will tell the generator how to update itself
@ -238,10 +239,12 @@ int main(int argc, char** argv) try
serialize("dcgan_sync") << generator << discriminator << iteration;
std::cout <<
"step#: " << iteration <<
"\tdiscriminator loss: " << d_loss <<
"\tgenerator loss: " << g_loss << '\n';
"\tdiscriminator loss: " << d_loss.mean() * 2 <<
"\tgenerator loss: " << g_loss.mean() << '\n';
win.set_image(tile_images(fake_samples));
win.set_title("DCGAN step#: " + to_string(iteration));
d_loss.clear();
g_loss.clear();
}
}