Made add_loss_layer's batch operator() more general.

This commit is contained in:
Davis King 2016-01-23 19:48:47 -05:00
parent 846a570479
commit a9e1c9e457
2 changed files with 8 additions and 3 deletions

View File

@ -1923,12 +1923,13 @@ namespace dlib
return temp_label;
}
template <typename iterable_type>
std::vector<label_type> operator() (
const std::vector<input_type>& data,
const iterable_type& data,
size_t batch_size = 128
)
{
std::vector<label_type> results(data.size());
std::vector<label_type> results(std::distance(data.begin(), data.end()));
auto o = results.begin();
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
{

View File

@ -688,13 +688,17 @@ namespace dlib
label_type.
!*/
template <typename iterable_type>
std::vector<label_type> operator() (
const std::vector<input_type>& data,
const iterable_type& data,
size_t batch_size = 128
);
/*!
requires
- batch_size > 0
- data must have a .begin() and .end() that supply iterators over a
sequence of input_type elements. E.g. data could have a type of
std::vector<input_type>
ensures
- runs all the objects in data through the network and returns their
predicted labels. This means this function returns a vector V such that: