mirror of https://github.com/davisking/dlib.git
Learn only one class
This commit is contained in:
parent
edb083709a
commit
a841eff8a9
|
@ -528,21 +528,54 @@ std::vector<std::vector<truth_instance>> load_all_truth_instances(const std::vec
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void filter_listing(
|
||||
std::vector<image_info>& listing,
|
||||
std::vector<std::vector<truth_instance>>& truth_instances,
|
||||
const std::string& desired_classlabel
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(listing.size() == truth_instances.size());
|
||||
|
||||
std::vector<image_info> filtered_listing;
|
||||
std::vector<std::vector<truth_instance>> filtered_truth_instances;
|
||||
|
||||
const auto represents_desired_class = [desired_classlabel](const truth_instance& truth_instance) {
|
||||
return truth_instance.mmod_rect.label == desired_classlabel;
|
||||
};
|
||||
|
||||
for (int i = 0, end = listing.size(); i < end; ++i)
|
||||
{
|
||||
const auto has_desired_class = std::any_of(
|
||||
truth_instances[i].begin(),
|
||||
truth_instances[i].end(),
|
||||
represents_desired_class
|
||||
);
|
||||
|
||||
if (has_desired_class) {
|
||||
filtered_listing.push_back(listing[i]);
|
||||
filtered_truth_instances.push_back(truth_instances[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::swap(listing, filtered_listing);
|
||||
std::swap(truth_instances, filtered_truth_instances);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) try
|
||||
{
|
||||
if (argc < 2 || argc > 4)
|
||||
if (argc < 2 || argc > 5)
|
||||
{
|
||||
cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
|
||||
cout << endl;
|
||||
cout << "You call this program like this: " << endl;
|
||||
cout << "./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size]" << endl;
|
||||
cout << "./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-label]" << endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl;
|
||||
|
||||
const auto listing = get_pascal_voc2012_train_listing(argv[1]);
|
||||
cout << "images in dataset: " << listing.size() << endl;
|
||||
auto listing = get_pascal_voc2012_train_listing(argv[1]);
|
||||
cout << "images in entire dataset: " << listing.size() << endl;
|
||||
if (listing.size() == 0)
|
||||
{
|
||||
cout << "Didn't find the VOC2012 dataset. " << endl;
|
||||
|
@ -550,16 +583,23 @@ int main(int argc, char** argv) try
|
|||
}
|
||||
|
||||
// mini-batches smaller than the default can be used with GPUs having less memory
|
||||
const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 40;
|
||||
const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 60;
|
||||
const unsigned int seg_minibatch_size = argc >= 4 ? std::stoi(argv[3]) : 25;
|
||||
cout << "det mini-batch size: " << det_minibatch_size << endl;
|
||||
cout << "seg mini-batch size: " << seg_minibatch_size << endl;
|
||||
|
||||
const std::string desired_classlabel = argc >= 5 ? argv[4] : "sheep";
|
||||
cout << "desired classlabel: " << desired_classlabel << endl;
|
||||
|
||||
// extract the MMOD rects
|
||||
cout << "\nExtracting all truth instances...";
|
||||
const auto truth_instances = load_all_truth_instances(listing);
|
||||
cout << endl << "Extracting all truth instances...";
|
||||
auto truth_instances = load_all_truth_instances(listing);
|
||||
cout << " Done!" << endl << endl;
|
||||
|
||||
filter_listing(listing, truth_instances, desired_classlabel);
|
||||
|
||||
cout << "images in dataset filtered by class: " << listing.size() << endl << endl;
|
||||
|
||||
// First train a detection network (loss_mmod), and then a mask segmentation network (loss_log_per_pixel)
|
||||
const auto det_net = train_detection_network (listing, truth_instances, det_minibatch_size);
|
||||
const auto seg_net = train_segmentation_network (listing, truth_instances, seg_minibatch_size);
|
||||
|
|
Loading…
Reference in New Issue