mirror of https://github.com/davisking/dlib.git
use running stats to track losses (#2041)
This commit is contained in:
parent
0057461a62
commit
57bb5eb58d
|
@ -181,6 +181,7 @@ int main(int argc, char** argv) try
|
||||||
const std::vector<float> fake_labels(minibatch_size, -1);
|
const std::vector<float> fake_labels(minibatch_size, -1);
|
||||||
dlib::image_window win;
|
dlib::image_window win;
|
||||||
resizable_tensor real_samples_tensor, fake_samples_tensor, noises_tensor;
|
resizable_tensor real_samples_tensor, fake_samples_tensor, noises_tensor;
|
||||||
|
running_stats<double> g_loss, d_loss;
|
||||||
while (iteration < 50000)
|
while (iteration < 50000)
|
||||||
{
|
{
|
||||||
// Train the discriminator with real images
|
// Train the discriminator with real images
|
||||||
|
@ -192,7 +193,7 @@ int main(int argc, char** argv) try
|
||||||
}
|
}
|
||||||
// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
|
// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
|
||||||
discriminator.to_tensor(real_samples.begin(), real_samples.end(), real_samples_tensor);
|
discriminator.to_tensor(real_samples.begin(), real_samples.end(), real_samples_tensor);
|
||||||
double d_loss = discriminator.compute_loss(real_samples_tensor, real_labels.begin());
|
d_loss.add(discriminator.compute_loss(real_samples_tensor, real_labels.begin()));
|
||||||
discriminator.back_propagate_error(real_samples_tensor);
|
discriminator.back_propagate_error(real_samples_tensor);
|
||||||
discriminator.update_parameters(d_solvers, learning_rate);
|
discriminator.update_parameters(d_solvers, learning_rate);
|
||||||
|
|
||||||
|
@ -210,7 +211,7 @@ int main(int argc, char** argv) try
|
||||||
// 4. finally train the discriminator and wait for the threading to stop. The following
|
// 4. finally train the discriminator and wait for the threading to stop. The following
|
||||||
// lines are equivalent to calling train_one_step(fake_samples, fake_labels)
|
// lines are equivalent to calling train_one_step(fake_samples, fake_labels)
|
||||||
discriminator.to_tensor(fake_samples.begin(), fake_samples.end(), fake_samples_tensor);
|
discriminator.to_tensor(fake_samples.begin(), fake_samples.end(), fake_samples_tensor);
|
||||||
d_loss += discriminator.compute_loss(fake_samples_tensor, fake_labels.begin());
|
d_loss.add(discriminator.compute_loss(fake_samples_tensor, fake_labels.begin()));
|
||||||
discriminator.back_propagate_error(fake_samples_tensor);
|
discriminator.back_propagate_error(fake_samples_tensor);
|
||||||
discriminator.update_parameters(d_solvers, learning_rate);
|
discriminator.update_parameters(d_solvers, learning_rate);
|
||||||
|
|
||||||
|
@ -223,7 +224,7 @@ int main(int argc, char** argv) try
|
||||||
// seen as test_one_step() plus the error back propagation.
|
// seen as test_one_step() plus the error back propagation.
|
||||||
|
|
||||||
// Forward the fake samples and compute the loss with real labels
|
// Forward the fake samples and compute the loss with real labels
|
||||||
const auto g_loss = discriminator.compute_loss(fake_samples_tensor, real_labels.begin());
|
g_loss.add(discriminator.compute_loss(fake_samples_tensor, real_labels.begin()));
|
||||||
// Back propagate the error to fill the final data gradient
|
// Back propagate the error to fill the final data gradient
|
||||||
discriminator.back_propagate_error(fake_samples_tensor);
|
discriminator.back_propagate_error(fake_samples_tensor);
|
||||||
// Get the gradient that will tell the generator how to update itself
|
// Get the gradient that will tell the generator how to update itself
|
||||||
|
@ -238,10 +239,12 @@ int main(int argc, char** argv) try
|
||||||
serialize("dcgan_sync") << generator << discriminator << iteration;
|
serialize("dcgan_sync") << generator << discriminator << iteration;
|
||||||
std::cout <<
|
std::cout <<
|
||||||
"step#: " << iteration <<
|
"step#: " << iteration <<
|
||||||
"\tdiscriminator loss: " << d_loss <<
|
"\tdiscriminator loss: " << d_loss.mean() * 2 <<
|
||||||
"\tgenerator loss: " << g_loss << '\n';
|
"\tgenerator loss: " << g_loss.mean() << '\n';
|
||||||
win.set_image(tile_images(fake_samples));
|
win.set_image(tile_images(fake_samples));
|
||||||
win.set_title("DCGAN step#: " + to_string(iteration));
|
win.set_title("DCGAN step#: " + to_string(iteration));
|
||||||
|
d_loss.clear();
|
||||||
|
g_loss.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue