mirror of https://github.com/davisking/dlib.git
Made the multi_device_tensor_averager not assume the size of the tensors is
known at set() time.
This commit is contained in:
parent
b85688acec
commit
0d6e3f12d6
|
@ -1067,14 +1067,6 @@ namespace dlib { namespace tt
|
|||
epa.emplace_back(new enable_peer_access(*g[0], *g[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// If there are multiple groups then we need to use the accum_buffer space
|
||||
// when talking across groups. So allocate that buffer now.
|
||||
if (accessible_groups.size() > 1)
|
||||
{
|
||||
raii_set_device set_dev(*accessible_groups[0][0]);
|
||||
accum_buffer.copy_size(*accessible_groups[0][0]);
|
||||
}
|
||||
}
|
||||
|
||||
void average()
|
||||
|
@ -1108,6 +1100,7 @@ namespace dlib { namespace tt
|
|||
{
|
||||
tensor& total_avg = *accessible_groups[0][0];
|
||||
raii_set_device set_dev(total_avg);
|
||||
accum_buffer.copy_size(total_avg);
|
||||
// now we need to average things across groups
|
||||
for (size_t i = 1; i < accessible_groups.size(); ++i)
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue