Add batch inference on C++ (#7915)

* Add batch inference on C++

* Return default params

* Add make_nms parameter
This commit is contained in:
Sergey Nuzhny 2021-07-18 18:58:01 +03:00 committed by GitHub
parent 5e73447fa8
commit d669680879
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 78 additions and 8 deletions

View File

@ -56,7 +56,7 @@ struct bbox_t_container {
#include <opencv2/imgproc/imgproc_c.h> // C #include <opencv2/imgproc/imgproc_c.h> // C
#endif #endif
extern "C" LIB_API int init(const char *configurationFilename, const char *weightsFilename, int gpu); extern "C" LIB_API int init(const char *configurationFilename, const char *weightsFilename, int gpu, int batch_size);
extern "C" LIB_API int detect_image(const char *filename, bbox_t_container &container); extern "C" LIB_API int detect_image(const char *filename, bbox_t_container &container);
extern "C" LIB_API int detect_mat(const uint8_t* data, const size_t data_length, bbox_t_container &container); extern "C" LIB_API int detect_mat(const uint8_t* data, const size_t data_length, bbox_t_container &container);
extern "C" LIB_API int dispose(); extern "C" LIB_API int dispose();
@ -76,11 +76,12 @@ public:
float nms = .4; float nms = .4;
bool wait_stream; bool wait_stream;
LIB_API Detector(std::string cfg_filename, std::string weight_filename, int gpu_id = 0); LIB_API Detector(std::string cfg_filename, std::string weight_filename, int gpu_id = 0, int batch_size = 1);
LIB_API ~Detector(); LIB_API ~Detector();
LIB_API std::vector<bbox_t> detect(std::string image_filename, float thresh = 0.2, bool use_mean = false); LIB_API std::vector<bbox_t> detect(std::string image_filename, float thresh = 0.2, bool use_mean = false);
LIB_API std::vector<bbox_t> detect(image_t img, float thresh = 0.2, bool use_mean = false); LIB_API std::vector<bbox_t> detect(image_t img, float thresh = 0.2, bool use_mean = false);
LIB_API std::vector<std::vector<bbox_t>> detectBatch(image_t img, int batch_size, int width, int height, float thresh, bool make_nms = true);
static LIB_API image_t load_image(std::string image_filename); static LIB_API image_t load_image(std::string image_filename);
static LIB_API void free_image(image_t m); static LIB_API void free_image(image_t m);
LIB_API int get_net_width() const; LIB_API int get_net_width() const;

View File

@ -27,9 +27,9 @@ extern "C" {
//static Detector* detector = NULL; //static Detector* detector = NULL;
static std::unique_ptr<Detector> detector; static std::unique_ptr<Detector> detector;
int init(const char *configurationFilename, const char *weightsFilename, int gpu) int init(const char *configurationFilename, const char *weightsFilename, int gpu, int batch_size)
{ {
detector.reset(new Detector(configurationFilename, weightsFilename, gpu)); detector.reset(new Detector(configurationFilename, weightsFilename, gpu, batch_size));
return 1; return 1;
} }
@ -127,7 +127,8 @@ struct detector_gpu_t {
unsigned int *track_id; unsigned int *track_id;
}; };
LIB_API Detector::Detector(std::string cfg_filename, std::string weight_filename, int gpu_id) : cur_gpu_id(gpu_id) LIB_API Detector::Detector(std::string cfg_filename, std::string weight_filename, int gpu_id, int batch_size)
: cur_gpu_id(gpu_id)
{ {
wait_stream = 0; wait_stream = 0;
#ifdef GPU #ifdef GPU
@ -153,11 +154,11 @@ LIB_API Detector::Detector(std::string cfg_filename, std::string weight_filename
char *cfgfile = const_cast<char *>(_cfg_filename.c_str()); char *cfgfile = const_cast<char *>(_cfg_filename.c_str());
char *weightfile = const_cast<char *>(_weight_filename.c_str()); char *weightfile = const_cast<char *>(_weight_filename.c_str());
net = parse_network_cfg_custom(cfgfile, 1, 1); net = parse_network_cfg_custom(cfgfile, batch_size, batch_size);
if (weightfile) { if (weightfile) {
load_weights(&net, weightfile); load_weights(&net, weightfile);
} }
set_batch_network(&net, 1); set_batch_network(&net, batch_size);
net.gpu_index = cur_gpu_id; net.gpu_index = cur_gpu_id;
fuse_conv_batchnorm(net); fuse_conv_batchnorm(net);
@ -354,6 +355,74 @@ LIB_API std::vector<bbox_t> Detector::detect(image_t img, float thresh, bool use
return bbox_vec; return bbox_vec;
} }
LIB_API std::vector<std::vector<bbox_t>> Detector::detectBatch(image_t img, int batch_size, int width, int height, float thresh, bool make_nms)
{
detector_gpu_t &detector_gpu = *static_cast<detector_gpu_t *>(detector_gpu_ptr.get());
network &net = detector_gpu.net;
#ifdef GPU
int old_gpu_index;
cudaGetDevice(&old_gpu_index);
if(cur_gpu_id != old_gpu_index)
cudaSetDevice(net.gpu_index);
net.wait_stream = wait_stream; // 1 - wait CUDA-stream, 0 - not to wait
#endif
//std::cout << "net.gpu_index = " << net.gpu_index << std::endl;
layer l = net.layers[net.n - 1];
float hier_thresh = 0.5;
image in_img;
in_img.c = img.c;
in_img.w = img.w;
in_img.h = img.h;
in_img.data = img.data;
det_num_pair* prediction = network_predict_batch(&net, in_img, batch_size, width, height, thresh, hier_thresh, 0, 0, 0);
std::vector<std::vector<bbox_t>> bbox_vec(batch_size);
for (int bi = 0; bi < batch_size; ++bi)
{
auto dets = prediction[bi].dets;
if (make_nms && nms)
do_nms_sort(dets, prediction[bi].num, l.classes, nms);
for (int i = 0; i < prediction[bi].num; ++i)
{
box b = dets[i].bbox;
int const obj_id = max_index(dets[i].prob, l.classes);
float const prob = dets[i].prob[obj_id];
if (prob > thresh)
{
bbox_t bbox;
bbox.x = std::max((double)0, (b.x - b.w / 2.));
bbox.y = std::max((double)0, (b.y - b.h / 2.));
bbox.w = b.w;
bbox.h = b.h;
bbox.obj_id = obj_id;
bbox.prob = prob;
bbox.track_id = 0;
bbox.frames_counter = 0;
bbox.x_3d = NAN;
bbox.y_3d = NAN;
bbox.z_3d = NAN;
bbox_vec[bi].push_back(bbox);
}
}
}
free_batch_detections(prediction, batch_size);
#ifdef GPU
if (cur_gpu_id != old_gpu_index)
cudaSetDevice(old_gpu_index);
#endif
return bbox_vec;
}
LIB_API std::vector<bbox_t> Detector::tracking_id(std::vector<bbox_t> cur_bbox_vec, bool const change_history, LIB_API std::vector<bbox_t> Detector::tracking_id(std::vector<bbox_t> cur_bbox_vec, bool const change_history,
int const frames_story, int const max_dist) int const frames_story, int const max_dist)
{ {