diff --git a/dlib/dnn/tensor_tools.h b/dlib/dnn/tensor_tools.h index d16257004..e14fb307d 100644 --- a/dlib/dnn/tensor_tools.h +++ b/dlib/dnn/tensor_tools.h @@ -1067,14 +1067,6 @@ namespace dlib { namespace tt epa.emplace_back(new enable_peer_access(*g[0], *g[i])); } } - - // If there are multiple groups then we need to use the accum_buffer space - // when talking across groups. So allocate that buffer now. - if (accessible_groups.size() > 1) - { - raii_set_device set_dev(*accessible_groups[0][0]); - accum_buffer.copy_size(*accessible_groups[0][0]); - } } void average() @@ -1108,6 +1100,7 @@ namespace dlib { namespace tt { tensor& total_avg = *accessible_groups[0][0]; raii_set_device set_dev(total_avg); + accum_buffer.copy_size(total_avg); // now we need to average things across groups for (size_t i = 1; i < accessible_groups.size(); ++i) {