mirror of https://github.com/davisking/dlib.git
Made add_loss_layer's batch operator() more general.
This commit is contained in:
parent
846a570479
commit
a9e1c9e457
|
@ -1923,12 +1923,13 @@ namespace dlib
|
||||||
return temp_label;
|
return temp_label;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename iterable_type>
|
||||||
std::vector<label_type> operator() (
|
std::vector<label_type> operator() (
|
||||||
const std::vector<input_type>& data,
|
const iterable_type& data,
|
||||||
size_t batch_size = 128
|
size_t batch_size = 128
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
std::vector<label_type> results(data.size());
|
std::vector<label_type> results(std::distance(data.begin(), data.end()));
|
||||||
auto o = results.begin();
|
auto o = results.begin();
|
||||||
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
|
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
|
||||||
{
|
{
|
||||||
|
|
|
@ -688,13 +688,17 @@ namespace dlib
|
||||||
label_type.
|
label_type.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
template <typename iterable_type>
|
||||||
std::vector<label_type> operator() (
|
std::vector<label_type> operator() (
|
||||||
const std::vector<input_type>& data,
|
const iterable_type& data,
|
||||||
size_t batch_size = 128
|
size_t batch_size = 128
|
||||||
);
|
);
|
||||||
/*!
|
/*!
|
||||||
requires
|
requires
|
||||||
- batch_size > 0
|
- batch_size > 0
|
||||||
|
- data must have a .begin() and .end() that supply iterators over a
|
||||||
|
sequence of input_type elements. E.g. data could have a type of
|
||||||
|
std::vector<input_type>
|
||||||
ensures
|
ensures
|
||||||
- runs all the objects in data through the network and returns their
|
- runs all the objects in data through the network and returns their
|
||||||
predicted labels. This means this function returns a vector V such that:
|
predicted labels. This means this function returns a vector V such that:
|
||||||
|
|
Loading…
Reference in New Issue