mirror of https://github.com/davisking/dlib.git
Perform conv with bias for specific algo with cudnn (#2839)
This commit is contained in:
parent
7b6021eee7
commit
c7b2917498
|
@ -928,7 +928,7 @@ namespace dlib
|
|||
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
|
||||
}
|
||||
#endif
|
||||
backward_filters_algo = backward_filters_best_algo;
|
||||
backward_filters_algo = backward_filters_best_algo;
|
||||
|
||||
// Save this algorithm selection in the cache
|
||||
config_to_algo_cache[cache_key] = std::make_tuple(forward_algo, backward_data_algo, backward_filters_algo);
|
||||
|
@ -1177,6 +1177,19 @@ namespace dlib
|
|||
const tensor& biases
|
||||
)
|
||||
{
|
||||
|
||||
// cudnnConvolutionBiasActivationForward supports CUDNN_ACTIVATION_IDENTITY only when
|
||||
// the chosen forward algorithm is CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, as the cuDNN documentation explicitly states.
|
||||
// In case the algorithm is different, perform the forward pass and bias addition separately.
|
||||
if (forward_algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
|
||||
{
|
||||
(*this)(add_to_output, output, data, filters);
|
||||
|
||||
tt::add(1, output, 1, biases);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
DLIB_CASSERT(is_same_object(output,data) == false);
|
||||
DLIB_CASSERT(is_same_object(output,filters) == false);
|
||||
DLIB_CASSERT(filters.k() == data.k());
|
||||
|
|
|
@ -233,17 +233,19 @@ namespace dlib
|
|||
padding_y_,
|
||||
padding_x_);
|
||||
|
||||
conv(false, output,
|
||||
sub.get_output(),
|
||||
filters(params,0));
|
||||
|
||||
// For some reason, doing this is sometimes slower than two separate calls
|
||||
// conv(false, output,
|
||||
// sub.get_output(),
|
||||
// filters(params,0),
|
||||
// biases(params, filters.size()));
|
||||
if (use_bias)
|
||||
tt::add(1,output,1,biases(params,filters.size()));
|
||||
{
|
||||
conv(false, output,
|
||||
sub.get_output(),
|
||||
filters(params,0),
|
||||
biases(params, filters.size()));
|
||||
}
|
||||
else
|
||||
{
|
||||
conv(false, output,
|
||||
sub.get_output(),
|
||||
filters(params,0));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
|
|
Loading…
Reference in New Issue