diff --git a/train.py b/train.py index 7bb13cc..6df18c0 100755 --- a/train.py +++ b/train.py @@ -131,7 +131,7 @@ def train(opt): start_time = time.time() best_accuracy = -1 - best_norm_ED = 1e+6 + best_norm_ED = -1 i = start_iter while(True): @@ -191,7 +191,7 @@ def train(opt): if current_accuracy > best_accuracy: best_accuracy = current_accuracy torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth') - if current_norm_ED < best_norm_ED: + if current_norm_ED > best_norm_ED: best_norm_ED = current_norm_ED torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth') best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'