mirror of https://github.com/davisking/dlib.git
use running stats to track losses (#2041)
This commit is contained in:
parent
0057461a62
commit
57bb5eb58d
|
@ -181,6 +181,7 @@ int main(int argc, char** argv) try
|
|||
const std::vector<float> fake_labels(minibatch_size, -1);
|
||||
dlib::image_window win;
|
||||
resizable_tensor real_samples_tensor, fake_samples_tensor, noises_tensor;
|
||||
running_stats<double> g_loss, d_loss;
|
||||
while (iteration < 50000)
|
||||
{
|
||||
// Train the discriminator with real images
|
||||
|
@ -192,7 +193,7 @@ int main(int argc, char** argv) try
|
|||
}
|
||||
// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
|
||||
discriminator.to_tensor(real_samples.begin(), real_samples.end(), real_samples_tensor);
|
||||
double d_loss = discriminator.compute_loss(real_samples_tensor, real_labels.begin());
|
||||
d_loss.add(discriminator.compute_loss(real_samples_tensor, real_labels.begin()));
|
||||
discriminator.back_propagate_error(real_samples_tensor);
|
||||
discriminator.update_parameters(d_solvers, learning_rate);
|
||||
|
||||
|
@ -210,7 +211,7 @@ int main(int argc, char** argv) try
|
|||
// 4. finally train the discriminator and wait for the threading to stop. The following
|
||||
// lines are equivalent to calling train_one_step(fake_samples, fake_labels)
|
||||
discriminator.to_tensor(fake_samples.begin(), fake_samples.end(), fake_samples_tensor);
|
||||
d_loss += discriminator.compute_loss(fake_samples_tensor, fake_labels.begin());
|
||||
d_loss.add(discriminator.compute_loss(fake_samples_tensor, fake_labels.begin()));
|
||||
discriminator.back_propagate_error(fake_samples_tensor);
|
||||
discriminator.update_parameters(d_solvers, learning_rate);
|
||||
|
||||
|
@ -223,7 +224,7 @@ int main(int argc, char** argv) try
|
|||
// seen as test_one_step() plus the error back propagation.
|
||||
|
||||
// Forward the fake samples and compute the loss with real labels
|
||||
const auto g_loss = discriminator.compute_loss(fake_samples_tensor, real_labels.begin());
|
||||
g_loss.add(discriminator.compute_loss(fake_samples_tensor, real_labels.begin()));
|
||||
// Back propagate the error to fill the final data gradient
|
||||
discriminator.back_propagate_error(fake_samples_tensor);
|
||||
// Get the gradient that will tell the generator how to update itself
|
||||
|
@ -238,10 +239,12 @@ int main(int argc, char** argv) try
|
|||
serialize("dcgan_sync") << generator << discriminator << iteration;
|
||||
std::cout <<
|
||||
"step#: " << iteration <<
|
||||
"\tdiscriminator loss: " << d_loss <<
|
||||
"\tgenerator loss: " << g_loss << '\n';
|
||||
"\tdiscriminator loss: " << d_loss.mean() * 2 <<
|
||||
"\tgenerator loss: " << g_loss.mean() << '\n';
|
||||
win.set_image(tile_images(fake_samples));
|
||||
win.set_title("DCGAN step#: " + to_string(iteration));
|
||||
d_loss.clear();
|
||||
g_loss.clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue