mirror of https://github.com/davisking/dlib.git
Made add_loss_layer's batch operator() more general.
This commit is contained in:
parent
846a570479
commit
a9e1c9e457
|
@ -1923,12 +1923,13 @@ namespace dlib
|
|||
return temp_label;
|
||||
}
|
||||
|
||||
template <typename iterable_type>
|
||||
std::vector<label_type> operator() (
|
||||
const std::vector<input_type>& data,
|
||||
const iterable_type& data,
|
||||
size_t batch_size = 128
|
||||
)
|
||||
{
|
||||
std::vector<label_type> results(data.size());
|
||||
std::vector<label_type> results(std::distance(data.begin(), data.end()));
|
||||
auto o = results.begin();
|
||||
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
|
||||
{
|
||||
|
|
|
@ -688,13 +688,17 @@ namespace dlib
|
|||
label_type.
|
||||
!*/
|
||||
|
||||
template <typename iterable_type>
|
||||
std::vector<label_type> operator() (
|
||||
const std::vector<input_type>& data,
|
||||
const iterable_type& data,
|
||||
size_t batch_size = 128
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- batch_size > 0
|
||||
- data must have a .begin() and .end() that supply iterators over a
|
||||
sequence of input_type elements. E.g. data could have a type of
|
||||
std::vector<input_type>
|
||||
ensures
|
||||
- runs all the objects in data through the network and returns their
|
||||
predicted labels. This means this function returns a vector V such that:
|
||||
|
|
Loading…
Reference in New Issue