mirror of https://github.com/davisking/dlib.git
Added an overload of operator() that lets you easily run a network on an
entire std::vector of objects.
This commit is contained in:
parent
93ab80c758
commit
846a570479
|
@ -15,6 +15,7 @@
|
|||
#include <utility>
|
||||
#include <tuple>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include "tensor_tools.h"
|
||||
|
||||
|
||||
|
@ -1922,6 +1923,21 @@ namespace dlib
|
|||
return temp_label;
|
||||
}
|
||||
|
||||
std::vector<label_type> operator() (
|
||||
const std::vector<input_type>& data,
|
||||
size_t batch_size = 128
|
||||
)
|
||||
{
|
||||
std::vector<label_type> results(data.size());
|
||||
auto o = results.begin();
|
||||
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
|
||||
{
|
||||
auto end = std::min(i+batch_size, data.end());
|
||||
(*this)(i, end, o);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
template <typename label_iterator>
|
||||
double compute_loss (
|
||||
const tensor& x,
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include "../rand.h"
|
||||
|
||||
|
||||
|
@ -687,6 +688,25 @@ namespace dlib
|
|||
label_type.
|
||||
!*/
|
||||
|
||||
std::vector<label_type> operator() (
|
||||
const std::vector<input_type>& data,
|
||||
size_t batch_size = 128
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- batch_size > 0
|
||||
ensures
|
||||
- runs all the objects in data through the network and returns their
|
||||
predicted labels. This means this function returns a vector V such that:
|
||||
- V.size() == data.size()
|
||||
- for all valid i: V[i] == the predicted label of data[i].
|
||||
- Elements of data are run through the network in batches of batch_size
|
||||
items. Using a batch_size > 1 can be faster because it better exploits
|
||||
the available hardware parallelism.
|
||||
- loss_details().to_label() is used to convert the network output into a
|
||||
label_type.
|
||||
!*/
|
||||
|
||||
// -------------
|
||||
|
||||
template <typename label_iterator>
|
||||
|
|
Loading…
Reference in New Issue