mirror of https://github.com/davisking/dlib.git
Fixed some resource leaks. Also fixed max_pool so it does exactly what the
spec says it should.
This commit is contained in:
parent
cbd57be677
commit
7ae43ae2d5
|
@ -480,6 +480,7 @@ namespace dlib
|
|||
catch(...)
|
||||
{
|
||||
clear();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -581,7 +582,7 @@ namespace dlib
|
|||
// ------------------------------------------------------------------------------------
|
||||
|
||||
max_pool::max_pool (
|
||||
) : handle(nullptr),stride_y(0),stride_x(0)
|
||||
) : handle(nullptr),window_height(0),window_width(0),stride_y(0),stride_x(0)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -598,31 +599,52 @@ namespace dlib
|
|||
if (handle)
|
||||
cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle);
|
||||
handle = nullptr;
|
||||
window_height = 0;
|
||||
window_width = 0;
|
||||
stride_y = 0;
|
||||
stride_x = 0;
|
||||
}
|
||||
|
||||
void max_pool::
|
||||
setup(
|
||||
int window_height,
|
||||
int window_width,
|
||||
int window_height_,
|
||||
int window_width_,
|
||||
int stride_y_,
|
||||
int stride_x_
|
||||
)
|
||||
{
|
||||
stride_x = stride_x_;
|
||||
stride_y = stride_y_;
|
||||
cudnnPoolingDescriptor_t poolingDesc;
|
||||
CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc));
|
||||
handle = poolingDesc;
|
||||
if (window_height == window_height_ &&
|
||||
window_width == window_width_ &&
|
||||
stride_y == stride_y_ &&
|
||||
stride_x == stride_x_ )
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
|
||||
CUDNN_POOLING_MAX,
|
||||
window_height,
|
||||
window_width,
|
||||
0,0, // no padding
|
||||
stride_y,
|
||||
stride_x));
|
||||
clear();
|
||||
try
|
||||
{
|
||||
window_height = window_height_;
|
||||
window_width = window_width_;
|
||||
stride_x = stride_x_;
|
||||
stride_y = stride_y_;
|
||||
cudnnPoolingDescriptor_t poolingDesc;
|
||||
CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc));
|
||||
handle = poolingDesc;
|
||||
|
||||
CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
|
||||
CUDNN_POOLING_MAX,
|
||||
window_height,
|
||||
window_width,
|
||||
0,0, // no padding
|
||||
stride_y,
|
||||
stride_x));
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
clear();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
void max_pool::
|
||||
|
@ -649,8 +671,8 @@ namespace dlib
|
|||
|
||||
DLIB_CASSERT(dest.num_samples() == src.num_samples(),"");
|
||||
DLIB_CASSERT(dest.k() == src.k(),"");
|
||||
DLIB_CASSERT(dest.nr() == src.nr()/stride_y,"");
|
||||
DLIB_CASSERT(dest.nc() == src.nc()/stride_x,"");
|
||||
DLIB_CASSERT(dest.nr() == src.nr()/stride_y, stride_y << ", " << dest.nr() << " " << src.nr()/stride_y);
|
||||
DLIB_CASSERT(dest.nc() == src.nc()/stride_x, stride_x << ", " << dest.nc() << " " << src.nc()/stride_x);
|
||||
|
||||
CHECK_CUDNN(cudnnPoolingForward(context(),
|
||||
(const cudnnPoolingDescriptor_t)handle,
|
||||
|
@ -673,7 +695,7 @@ namespace dlib
|
|||
DLIB_CASSERT(have_same_dimensions(src,grad),"");
|
||||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
const float beta = 1;
|
||||
CHECK_CUDNN(cudnnPoolingBackward(context(),
|
||||
(const cudnnPoolingDescriptor_t)handle,
|
||||
&alpha,
|
||||
|
|
|
@ -328,6 +328,8 @@ namespace dlib
|
|||
|
||||
private:
|
||||
void* handle;
|
||||
int window_height;
|
||||
int window_width;
|
||||
int stride_y;
|
||||
int stride_x;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue