Use separate synchronization file for each seg net of each class

This commit is contained in:
Juha Reunanen 2019-10-29 14:37:36 +02:00
parent 2f17289803
commit 030c468003
1 changed files with 13 additions and 7 deletions

View File

@ -312,7 +312,8 @@ matrix<uint16_t> keep_only_current_instance(const matrix<rgb_pixel>& rgb_label_i
seg_bnet_type train_segmentation_network(
const std::vector<image_info>& listing,
const std::vector<std::vector<truth_instance>>& truth_instances,
unsigned int seg_minibatch_size
unsigned int seg_minibatch_size,
const std::string& classlabel
)
{
seg_bnet_type seg_net;
@ -321,10 +322,15 @@ seg_bnet_type train_segmentation_network(
const double weight_decay = 0.0001;
const double momentum = 0.9;
const std::string synchronization_file_name
= "pascal_voc2012_seg_trainer_state_file"
+ (classlabel.empty() ? "" : ("_" + classlabel))
+ ".dat";
dnn_trainer<seg_bnet_type> seg_trainer(seg_net, sgd(weight_decay, momentum));
seg_trainer.be_verbose();
seg_trainer.set_learning_rate(initial_learning_rate);
seg_trainer.set_synchronization_file("pascal_voc2012_seg_trainer_state_file.dat", std::chrono::minutes(10));
seg_trainer.set_synchronization_file(synchronization_file_name, std::chrono::minutes(10));
seg_trainer.set_iterations_without_progress_threshold(5000);
set_all_bn_running_stats_window_sizes(seg_net, 1000);
@ -626,10 +632,10 @@ int main(int argc, char** argv) try
filter_listing(listing, truth_instances, desired_classlabels);
cout << "images in dataset filtered by class: " << listing.size() << endl << endl;
cout << "images in dataset filtered by class: " << listing.size() << endl;
// First train a detection network (loss_mmod), and then a mask segmentation network (loss_log_per_pixel)
cout << "Training detector network:" << endl;
cout << endl << "Training detector network:" << endl;
const auto det_net = train_detection_network (listing, truth_instances, det_minibatch_size);
std::map<std::string, seg_bnet_type> seg_nets_by_class;
@ -644,16 +650,16 @@ int main(int argc, char** argv) try
auto truth_instances_for_classlabel = truth_instances;
filter_listing(listing_for_classlabel, truth_instances_for_classlabel, { classlabel });
cout << "Training segmentation network for class " << classlabel << ":" << endl;
cout << endl << "Training segmentation network for class " << classlabel << ":" << endl;
seg_nets_by_class[classlabel] = train_segmentation_network(
listing_for_classlabel, truth_instances_for_classlabel, seg_minibatch_size
listing_for_classlabel, truth_instances_for_classlabel, seg_minibatch_size, classlabel
);
}
}
else
{
cout << "Training a single segmentation network:" << endl;
seg_nets_by_class[""] = train_segmentation_network(listing, truth_instances, seg_minibatch_size);
seg_nets_by_class[""] = train_segmentation_network(listing, truth_instances, seg_minibatch_size, "");
}
cout << "Saving networks" << endl;