Made add_loss_layer's batch operator() more general.

This commit is contained in:
Davis King 2016-01-23 19:48:47 -05:00
parent 846a570479
commit a9e1c9e457
2 changed files with 8 additions and 3 deletions

View File

@ -1923,12 +1923,13 @@ namespace dlib
return temp_label; return temp_label;
} }
template <typename iterable_type>
std::vector<label_type> operator() ( std::vector<label_type> operator() (
const std::vector<input_type>& data, const iterable_type& data,
size_t batch_size = 128 size_t batch_size = 128
) )
{ {
std::vector<label_type> results(data.size()); std::vector<label_type> results(std::distance(data.begin(), data.end()));
auto o = results.begin(); auto o = results.begin();
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size) for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
{ {

View File

@ -688,13 +688,17 @@ namespace dlib
label_type. label_type.
!*/ !*/
template <typename iterable_type>
std::vector<label_type> operator() ( std::vector<label_type> operator() (
const std::vector<input_type>& data, const iterable_type& data,
size_t batch_size = 128 size_t batch_size = 128
); );
/*! /*!
requires requires
- batch_size > 0 - batch_size > 0
- data must have a .begin() and .end() that supply iterators over a
sequence of input_type elements. E.g. data could have a type of
std::vector<input_type>
ensures ensures
- runs all the objects in data through the network and returns their - runs all the objects in data through the network and returns their
predicted labels. This means this function returns a vector V such that: predicted labels. This means this function returns a vector V such that: