Fixed some resource leaks. Also fixed max_pool so it does exactly what the

spec says it should.
This commit is contained in:
Davis King 2015-12-12 12:52:32 -05:00
parent cbd57be677
commit 7ae43ae2d5
2 changed files with 42 additions and 18 deletions

View File

@ -480,6 +480,7 @@ namespace dlib
catch(...)
{
clear();
throw;
}
}
@ -581,7 +582,7 @@ namespace dlib
// ------------------------------------------------------------------------------------
max_pool::max_pool (
) : handle(nullptr),stride_y(0),stride_x(0)
) : handle(nullptr),window_height(0),window_width(0),stride_y(0),stride_x(0)
{
}
@ -598,31 +599,52 @@ namespace dlib
if (handle)
cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle);
handle = nullptr;
window_height = 0;
window_width = 0;
stride_y = 0;
stride_x = 0;
}
void max_pool::
setup(
int window_height,
int window_width,
int window_height_,
int window_width_,
int stride_y_,
int stride_x_
)
{
stride_x = stride_x_;
stride_y = stride_y_;
cudnnPoolingDescriptor_t poolingDesc;
CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc));
handle = poolingDesc;
if (window_height == window_height_ &&
window_width == window_width_ &&
stride_y == stride_y_ &&
stride_x == stride_x_ )
{
return;
}
CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
CUDNN_POOLING_MAX,
window_height,
window_width,
0,0, // no padding
stride_y,
stride_x));
clear();
try
{
window_height = window_height_;
window_width = window_width_;
stride_x = stride_x_;
stride_y = stride_y_;
cudnnPoolingDescriptor_t poolingDesc;
CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc));
handle = poolingDesc;
CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
CUDNN_POOLING_MAX,
window_height,
window_width,
0,0, // no padding
stride_y,
stride_x));
}
catch(...)
{
clear();
throw;
}
}
void max_pool::
@ -649,8 +671,8 @@ namespace dlib
DLIB_CASSERT(dest.num_samples() == src.num_samples(),"");
DLIB_CASSERT(dest.k() == src.k(),"");
DLIB_CASSERT(dest.nr() == src.nr()/stride_y,"");
DLIB_CASSERT(dest.nc() == src.nc()/stride_x,"");
DLIB_CASSERT(dest.nr() == src.nr()/stride_y, stride_y << ", " << dest.nr() << " " << src.nr()/stride_y);
DLIB_CASSERT(dest.nc() == src.nc()/stride_x, stride_x << ", " << dest.nc() << " " << src.nc()/stride_x);
CHECK_CUDNN(cudnnPoolingForward(context(),
(const cudnnPoolingDescriptor_t)handle,
@ -673,7 +695,7 @@ namespace dlib
DLIB_CASSERT(have_same_dimensions(src,grad),"");
const float alpha = 1;
const float beta = 0;
const float beta = 1;
CHECK_CUDNN(cudnnPoolingBackward(context(),
(const cudnnPoolingDescriptor_t)handle,
&alpha,

View File

@ -328,6 +328,8 @@ namespace dlib
private:
void* handle;
int window_height;
int window_width;
int stride_y;
int stride_x;
};