Merge pull request #2352 from aughey/master

Changes to better support python bindings
This commit is contained in:
Alexey 2019-02-06 14:23:26 +03:00 committed by GitHub
commit b76f1c0006
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 5 deletions

View File

@ -125,7 +125,7 @@ lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int
predict = lib.network_predict
predict = lib.network_predict_ptr
predict.argtypes = [c_void_p, POINTER(c_float)]
predict.restype = POINTER(c_float)

View File

@ -2137,3 +2137,22 @@ void free_image(image m)
free(m.data);
}
}
// Fast copy data from a contiguous byte array into the image.
LIB_API void copy_image_from_bytes(image im, char *pdata)
{
unsigned char *data = (unsigned char*)pdata;
int i, k, j;
int w = im.w;
int h = im.h;
int c = im.c;
for (k = 0; k < c; ++k) {
for (j = 0; j < h; ++j) {
for (i = 0; i < w; ++i) {
int dst_index = i + w * j + w * h*k;
int src_index = k + c * i + c * w*j;
im.data[dst_index] = (float)data[src_index] / 255.;
}
}
}
}

View File

@ -33,6 +33,7 @@ image random_crop_image(image im, int w, int h);
image random_augment_image(image im, float angle, float aspect, int low, int high, int size);
void random_distort_image(image im, float hue, float saturation, float exposure);
//LIB_API image resize_image(image im, int w, int h);
LIB_API void copy_image_from_bytes(image im, char *pdata);
void fill_image(image m, float s);
void letterbox_image_into(image im, int w, int h, image boxed);
//LIB_API image letterbox_image(image im, int w, int h);

View File

@ -556,6 +556,12 @@ void top_predictions(network net, int k, int *index)
top_k(out, size, k, index);
}
// A version of network_predict that uses a pointer for the network
// struct to make the python binding work properly.
float *network_predict_ptr(network *net, float *input)
{
return network_predict(*net, input);
}
float *network_predict(network net, float *input)
{
@ -731,10 +737,17 @@ char *detection_to_json(detection *dets, int nboxes, int classes, char **names,
float *network_predict_image(network *net, image im)
{
//image imr = letterbox_image(im, net->w, net->h);
image imr = resize_image(im, net->w, net->h);
set_batch_network(net, 1);
float *p = network_predict(*net, imr.data);
free_image(imr);
float *p;
if (im.w == net->w && im.h == net->h) {
// Input image is the same size as our net, predict on that image
p = network_predict(*net, im.data);
}
else {
// Need to resize image to the desired size for the net
image imr = resize_image(im, net->w, net->h);
p = network_predict(*net, imr.data);
free_image(imr);
}
return p;
}

View File

@ -122,6 +122,7 @@ float train_network_datum(network net, float *x, float *y);
matrix network_predict_data(network net, data test);
//LIB_API float *network_predict(network net, float *input);
LIB_API float *network_predict_ptr(network *net, float *input);
float network_accuracy(network net, data d);
float *network_accuracies(network net, data d, int n);
float network_accuracy_multi(network net, data d, int n);