diff --git a/examples/dnn_instance_segmentation_train_ex.cpp b/examples/dnn_instance_segmentation_train_ex.cpp index 619d4bb1a..99187bb43 100644 --- a/examples/dnn_instance_segmentation_train_ex.cpp +++ b/examples/dnn_instance_segmentation_train_ex.cpp @@ -528,21 +528,54 @@ std::vector> load_all_truth_instances(const std::vec // ---------------------------------------------------------------------------------------- +void filter_listing( + std::vector& listing, + std::vector>& truth_instances, + const std::string& desired_classlabel +) +{ + DLIB_CASSERT(listing.size() == truth_instances.size()); + + std::vector filtered_listing; + std::vector> filtered_truth_instances; + + const auto represents_desired_class = [desired_classlabel](const truth_instance& truth_instance) { + return truth_instance.mmod_rect.label == desired_classlabel; + }; + + for (int i = 0, end = listing.size(); i < end; ++i) + { + const auto has_desired_class = std::any_of( + truth_instances[i].begin(), + truth_instances[i].end(), + represents_desired_class + ); + + if (has_desired_class) { + filtered_listing.push_back(listing[i]); + filtered_truth_instances.push_back(truth_instances[i]); + } + } + + std::swap(listing, filtered_listing); + std::swap(truth_instances, filtered_truth_instances); +} + int main(int argc, char** argv) try { - if (argc < 2 || argc > 4) + if (argc < 2 || argc > 5) { cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl; cout << endl; cout << "You call this program like this: " << endl; - cout << "./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size]" << endl; + cout << "./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-label]" << endl; return 1; } cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl; - const auto listing = get_pascal_voc2012_train_listing(argv[1]); - cout << "images in dataset: " << listing.size() << endl; + auto listing = get_pascal_voc2012_train_listing(argv[1]); + cout << "images in entire dataset: " << listing.size() << endl; if (listing.size() == 0) { cout << "Didn't find the VOC2012 dataset. " << endl; @@ -550,16 +583,23 @@ int main(int argc, char** argv) try } // mini-batches smaller than the default can be used with GPUs having less memory - const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 40; + const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 60; const unsigned int seg_minibatch_size = argc >= 4 ? std::stoi(argv[3]) : 25; cout << "det mini-batch size: " << det_minibatch_size << endl; cout << "seg mini-batch size: " << seg_minibatch_size << endl; + const std::string desired_classlabel = argc >= 5 ? argv[4] : "sheep"; + cout << "desired classlabel: " << desired_classlabel << endl; + // extract the MMOD rects - cout << "\nExtracting all truth instances..."; - const auto truth_instances = load_all_truth_instances(listing); + cout << endl << "Extracting all truth instances..."; + auto truth_instances = load_all_truth_instances(listing); cout << " Done!" << endl << endl; + filter_listing(listing, truth_instances, desired_classlabel); + + cout << "images in dataset filtered by class: " << listing.size() << endl << endl; + // First train a detection network (loss_mmod), and then a mask segmentation network (loss_log_per_pixel) const auto det_net = train_detection_network (listing, truth_instances, det_minibatch_size); const auto seg_net = train_segmentation_network (listing, truth_instances, seg_minibatch_size);