Made the multi_device_tensor_averager not assume the size of the tensors is

known at set() time.
This commit is contained in:
Davis King 2016-04-29 06:55:04 -04:00
parent b85688acec
commit 0d6e3f12d6
1 changed files with 1 additions and 8 deletions

View File

@ -1067,14 +1067,6 @@ namespace dlib { namespace tt
epa.emplace_back(new enable_peer_access(*g[0], *g[i]));
}
}
// If there are multiple groups then we need to use the accum_buffer space
// when talking across groups. So allocate that buffer now.
if (accessible_groups.size() > 1)
{
raii_set_device set_dev(*accessible_groups[0][0]);
accum_buffer.copy_size(*accessible_groups[0][0]);
}
}
void average()
@ -1108,6 +1100,7 @@ namespace dlib { namespace tt
{
tensor& total_avg = *accessible_groups[0][0];
raii_set_device set_dev(total_avg);
accum_buffer.copy_size(total_avg);
// now we need to average things across groups
for (size_t i = 1; i < accessible_groups.size(); ++i)
{