From 846a5704795e5e3ad1a7beac2dee3790c9211a46 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sat, 23 Jan 2016 12:06:51 -0500 Subject: [PATCH] Added an overload of operator() that lets you easily run a network on an entire std::vector of objects. --- dlib/dnn/core.h | 16 ++++++++++++++++ dlib/dnn/core_abstract.h | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/dlib/dnn/core.h b/dlib/dnn/core.h index 909be0b72..95bde45f1 100644 --- a/dlib/dnn/core.h +++ b/dlib/dnn/core.h @@ -15,6 +15,7 @@ #include #include #include +#include #include "tensor_tools.h" @@ -1922,6 +1923,21 @@ namespace dlib return temp_label; } + std::vector operator() ( + const std::vector& data, + size_t batch_size = 128 + ) + { + std::vector results(data.size()); + auto o = results.begin(); + for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size) + { + auto end = std::min(i+batch_size, data.end()); + (*this)(i, end, o); + } + return results; + } + template double compute_loss ( const tensor& x, diff --git a/dlib/dnn/core_abstract.h b/dlib/dnn/core_abstract.h index 50d7444ed..3f893ac49 100644 --- a/dlib/dnn/core_abstract.h +++ b/dlib/dnn/core_abstract.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "../rand.h" @@ -687,6 +688,25 @@ namespace dlib label_type. !*/ + std::vector operator() ( + const std::vector& data, + size_t batch_size = 128 + ); + /*! + requires + - batch_size > 0 + ensures + - runs all the objects in data through the network and returns their + predicted labels. This means this function returns a vector V such that: + - V.size() == data.size() + - for all valid i: V[i] == the predicted label of data[i]. + - Elements of data are run through the network in batches of batch_size + items. Using a batch_size > 1 can be faster because it better exploits + the available hardware parallelism. + - loss_details().to_label() is used to convert the network output into a + label_type. + !*/ + // ------------- template