Added an overload of operator() that lets you easily run a network on an

entire std::vector of objects.
This commit is contained in:
Davis King 2016-01-23 12:06:51 -05:00
parent 93ab80c758
commit 846a570479
2 changed files with 36 additions and 0 deletions

View File

@ -15,6 +15,7 @@
#include <utility>
#include <tuple>
#include <cmath>
#include <vector>
#include "tensor_tools.h"
@ -1922,6 +1923,21 @@ namespace dlib
return temp_label;
}
std::vector<label_type> operator() (
const std::vector<input_type>& data,
size_t batch_size = 128
)
{
std::vector<label_type> results(data.size());
auto o = results.begin();
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
{
auto end = std::min(i+batch_size, data.end());
(*this)(i, end, o);
}
return results;
}
template <typename label_iterator>
double compute_loss (
const tensor& x,

View File

@ -7,6 +7,7 @@
#include <memory>
#include <type_traits>
#include <tuple>
#include <vector>
#include "../rand.h"
@ -687,6 +688,25 @@ namespace dlib
label_type.
!*/
std::vector<label_type> operator() (
const std::vector<input_type>& data,
size_t batch_size = 128
);
/*!
requires
- batch_size > 0
ensures
- runs all the objects in data through the network and returns their
predicted labels. This means this function returns a vector V such that:
- V.size() == data.size()
- for all valid i: V[i] == the predicted label of data[i].
- Elements of data are run through the network in batches of batch_size
items. Using a batch_size > 1 can be faster because it better exploits
the available hardware parallelism.
- loss_details().to_label() is used to convert the network output into a
label_type.
!*/
// -------------
template <typename label_iterator>