diff --git a/dlib/dnn/cudnn_dlibapi.cpp b/dlib/dnn/cudnn_dlibapi.cpp index e36f64fdc..5e0537267 100644 --- a/dlib/dnn/cudnn_dlibapi.cpp +++ b/dlib/dnn/cudnn_dlibapi.cpp @@ -480,6 +480,7 @@ namespace dlib catch(...) { clear(); + throw; } } @@ -581,7 +582,7 @@ namespace dlib // ------------------------------------------------------------------------------------ max_pool::max_pool ( - ) : handle(nullptr),stride_y(0),stride_x(0) + ) : handle(nullptr),window_height(0),window_width(0),stride_y(0),stride_x(0) { } @@ -598,31 +599,52 @@ namespace dlib if (handle) cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle); handle = nullptr; + window_height = 0; + window_width = 0; stride_y = 0; stride_x = 0; } void max_pool:: setup( - int window_height, - int window_width, + int window_height_, + int window_width_, int stride_y_, int stride_x_ ) { - stride_x = stride_x_; - stride_y = stride_y_; - cudnnPoolingDescriptor_t poolingDesc; - CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc)); - handle = poolingDesc; + if (window_height == window_height_ && + window_width == window_width_ && + stride_y == stride_y_ && + stride_x == stride_x_ ) + { + return; + } - CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc, - CUDNN_POOLING_MAX, - window_height, - window_width, - 0,0, // no padding - stride_y, - stride_x)); + clear(); + try + { + window_height = window_height_; + window_width = window_width_; + stride_x = stride_x_; + stride_y = stride_y_; + cudnnPoolingDescriptor_t poolingDesc; + CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc)); + handle = poolingDesc; + + CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc, + CUDNN_POOLING_MAX, + window_height, + window_width, + 0,0, // no padding + stride_y, + stride_x)); + } + catch(...) + { + clear(); + throw; + } } void max_pool:: @@ -649,8 +671,8 @@ namespace dlib DLIB_CASSERT(dest.num_samples() == src.num_samples(),""); DLIB_CASSERT(dest.k() == src.k(),""); - DLIB_CASSERT(dest.nr() == src.nr()/stride_y,""); - DLIB_CASSERT(dest.nc() == src.nc()/stride_x,""); + DLIB_CASSERT(dest.nr() == src.nr()/stride_y, stride_y << ", " << dest.nr() << " " << src.nr()/stride_y); + DLIB_CASSERT(dest.nc() == src.nc()/stride_x, stride_x << ", " << dest.nc() << " " << src.nc()/stride_x); CHECK_CUDNN(cudnnPoolingForward(context(), (const cudnnPoolingDescriptor_t)handle, @@ -673,7 +695,7 @@ namespace dlib DLIB_CASSERT(have_same_dimensions(src,grad),""); const float alpha = 1; - const float beta = 0; + const float beta = 1; CHECK_CUDNN(cudnnPoolingBackward(context(), (const cudnnPoolingDescriptor_t)handle, &alpha, diff --git a/dlib/dnn/cudnn_dlibapi.h b/dlib/dnn/cudnn_dlibapi.h index 55f6f41df..a95d0b300 100644 --- a/dlib/dnn/cudnn_dlibapi.h +++ b/dlib/dnn/cudnn_dlibapi.h @@ -328,6 +328,8 @@ namespace dlib private: void* handle; + int window_height; + int window_width; int stride_y; int stride_x; };