From f308cac6df712da3619fb05b14f3345f0ec07b9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= Date: Sun, 26 Sep 2021 20:47:00 +0900 Subject: [PATCH] try to fix it reusing the code... not sure though --- dlib/dnn/trainer.h | 46 +++++++--------------------------------------- 1 file changed, 7 insertions(+), 39 deletions(-) diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h index bf67ce33d..da714cb20 100644 --- a/dlib/dnn/trainer.h +++ b/dlib/dnn/trainer.h @@ -1194,6 +1194,9 @@ namespace dlib job.have_data.resize(devs); job.test_only = test_only; + // check if the iterator points to anything + const bool have_labels = sizeof(*lbegin) != 1; + // chop the data into devs blocks, each of about block_size elements. const double block_size = num / static_cast(devs); @@ -1211,7 +1214,8 @@ namespace dlib if (start < stop) { devices[i]->net.to_tensor(dbegin+start, dbegin+stop, job.t[i]); - job.labels[i].assign(lbegin+start, lbegin+stop); + if (have_labels) + job.labels[i].assign(lbegin+start, lbegin+stop); job.have_data[i] = true; } else @@ -1237,44 +1241,8 @@ namespace dlib data_iterator dend ) { - propagate_exception(); - size_t num = std::distance(dbegin, dend); - size_t devs = devices.size(); - job.t.resize(devs); - job.have_data.resize(devs); - job.test_only = test_only; - - // chop the data into devs blocks, each of about block_size elements. - const double block_size = num / static_cast(devs); - - const auto prev_dev = dlib::cuda::get_device(); - - double j = 0; - - for (size_t i = 0; i < devs; ++i) - { - dlib::cuda::set_device(devices[i]->device_id); - - const size_t start = static_cast(std::round(j)); - const size_t stop = static_cast(std::round(j + block_size)); - - if (start < stop) - { - devices[i]->net.to_tensor(dbegin+start, dbegin+stop, job.t[i]); - job.have_data[i] = true; - } - else - { - job.have_data[i] = false; - } - - j += block_size; - } - - DLIB_ASSERT(std::fabs(j - num) < 1e-10); - - dlib::cuda::set_device(prev_dev); - job_pipe.enqueue(job); + typename std::vector::iterator nothing; + send_job(test_only, dbegin, dend, nothing); } void print_progress()