mirror of https://github.com/AlexeyAB/darknet.git
361 lines
11 KiB
Python
361 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
|
|
"""
|
|
Python 3 wrapper for identifying objects in images
|
|
|
|
Running the script requires opencv-python and numpy to be installed (`pip install opencv-python numpy`)

Drawing bounding boxes on images (`draw_boxes`) is also done with opencv-python
|
|
Use pip3 instead of pip on some systems to be sure to install modules for python3
|
|
"""
|
|
|
|
import ctypes as ct
|
|
import random
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
|
|
|
|
class BOX(ct.Structure):
    """ctypes mirror of darknet's ``box`` struct.

    Four C floats laid out in the order x, y, w, h. Other helpers in this
    module (e.g. ``bbox2points``) treat x/y as the box centre and w/h as the
    box width and height.
    """

    # Every member is a C float, so build the field list from the names.
    _fields_ = tuple((member, ct.c_float) for member in ("x", "y", "w", "h"))
|
|
|
|
|
|
# Pointer types reused by the struct definitions and C prototypes.
IntPtr = ct.POINTER(ct.c_int)
FloatPtr = ct.POINTER(ct.c_float)
|
|
|
|
|
|
class DETECTION(ct.Structure):
    """ctypes mirror of darknet's ``detection`` struct.

    The member order and C types must match the C definition exactly;
    a single mismatch shifts every following field to the wrong offset.
    """

    _fields_ = (
        ("bbox", BOX),                 # geometry, read as (x, y, w, h)
        ("classes", ct.c_int),
        ("best_class_idx", ct.c_int),  # -1 is treated as "no class" by remove_negatives_faster
        ("prob", FloatPtr),            # per-class scores, indexed by class index
        ("mask", FloatPtr),
        ("objectness", ct.c_float),
        ("sort_class", ct.c_int),
        ("uc", FloatPtr),
        ("points", ct.c_int),
        ("embeddings", FloatPtr),
        ("embedding_size", ct.c_int),
        ("sim", ct.c_float),
        ("track_id", ct.c_int),
    )


# Type returned by get_network_boxes / make_network_boxes.
DETECTIONPtr = ct.POINTER(DETECTION)
|
|
|
|
|
|
class DETNUMPAIR(ct.Structure):
    """ctypes mirror of darknet's detection/count pair used for batches.

    ``dets`` points at an array of ``DETECTION`` and ``num`` is presumably
    its length (TODO confirm against the C side).
    """

    _fields_ = (
        ("num", ct.c_int),
        ("dets", DETECTIONPtr),
    )


# Type returned by network_predict_batch and taken by free_batch_detections.
DETNUMPAIRPtr = ct.POINTER(DETNUMPAIR)
|
|
|
|
|
|
class IMAGE(ct.Structure):
    """ctypes mirror of darknet's ``image`` struct.

    ``make_image`` builds one from ``(w, h, c)`` and ``detect_image`` forwards
    ``w``/``h`` to ``get_network_boxes``. ``data`` is a raw ``float*``
    (presumably the pixel values; layout not visible here — TODO confirm).
    """

    _fields_ = (
        ("w", ct.c_int),
        ("h", ct.c_int),
        ("c", ct.c_int),
        ("data", FloatPtr),
    )
|
|
|
|
|
|
class METADATA(ct.Structure):
    """ctypes mirror of the struct returned by darknet's ``get_metadata``.

    ``names`` is read as an array of ``classes`` C strings.
    """

    _fields_ = (
        ("classes", ct.c_int),               # number of entries in ``names``
        ("names", ct.POINTER(ct.c_char_p)),  # class-name byte strings
    )
|
|
|
|
|
|
def network_width(net):
    """Return the network input width reported by darknet's ``network_width``.

    ``net`` is the opaque handle produced by ``load_network``.
    """
    width = lib.network_width(net)
    return width
|
|
|
|
|
|
def network_height(net):
    """Return the network input height reported by darknet's ``network_height``.

    ``net`` is the opaque handle produced by ``load_network``.
    """
    height = lib.network_height(net)
    return height
|
|
|
|
|
|
def bbox2points(bbox):
    """Convert a YOLO-style ``(x, y, w, h)`` box into cv2 corner points.

    ``x``/``y`` are the box centre and ``w``/``h`` its size. The result is
    ``(xmin, ymin, xmax, ymax)``, each passed through the built-in ``round``
    (so exact .5 ties go to the nearest even integer).
    """
    cx, cy, width, height = bbox
    half_w = width / 2
    half_h = height / 2
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return tuple(round(value) for value in corners)
|
|
|
|
|
|
def class_colors(names):
    """Map every class name to a random BGR colour tuple.

    Each of the three channels is drawn with ``random.randint(0, 255)``; a
    name that appears more than once keeps the colour drawn for its last
    occurrence.
    """
    palette = {}
    for class_name in names:
        blue = random.randint(0, 255)
        green = random.randint(0, 255)
        red = random.randint(0, 255)
        palette[class_name] = (blue, green, red)
    return palette
|
|
|
|
|
|
def load_network(config_file, data_file, weights, batch_size=1):
    """Load model description and weights from darknet config files.

    Args:
        config_file (str | bytes | os.PathLike): path to the ``.cfg`` model file.
        data_file (str | bytes | os.PathLike): path to the ``.data`` file.
        weights (str | bytes | os.PathLike): path to the weights file.
        batch_size (int): batch size forwarded to ``load_network_custom``.

    Returns:
        tuple: ``(network, class_names, colors)`` where ``network`` is the
        handle returned by ``load_network_custom``, ``class_names`` is the
        list of class names from the metadata file, and ``colors`` maps each
        class name to a random BGR colour (see ``class_colors``).
    """
    # os.fsencode produces the same bytes as .encode("ascii") for ASCII
    # paths, but also accepts PathLike/bytes objects and non-ASCII file
    # names instead of raising UnicodeEncodeError.
    network = load_net_custom(
        os.fsencode(config_file),
        os.fsencode(weights),
        0,
        batch_size,
    )
    metadata = load_meta(os.fsencode(data_file))
    # UTF-8 is a superset of ASCII: existing names decode identically, while
    # non-ASCII class names no longer raise UnicodeDecodeError.
    class_names = [metadata.names[i].decode("utf-8") for i in range(metadata.classes)]
    colors = class_colors(class_names)
    return network, class_names, colors
|
|
|
|
|
|
def print_detections(detections, coordinates=False):
    """Print one line per detection to stdout.

    Args:
        detections: iterable of ``(label, confidence, bbox)`` tuples, where
            ``bbox`` is ``(x, y, w, h)`` with ``x``/``y`` the box centre (the
            same convention ``bbox2points`` uses).
        coordinates (bool): also print the box as left/top/width/height.
    """
    print("\nObjects:")
    for label, confidence, bbox in detections:
        if coordinates:
            x, y, w, h = bbox
            # bbox stores the centre point; shift it so the printed values
            # really are the left and top edges the labels promise.
            left_x = x - w / 2
            top_y = y - h / 2
            print("{}: {}% (left_x: {:.0f} top_y: {:.0f} width: {:.0f} height: {:.0f})".format(
                label, confidence, left_x, top_y, w, h))
        else:
            print("{}: {}%".format(label, confidence))
|
|
|
|
|
|
def draw_boxes(detections, image, colors):
    """Draw each detection's rectangle and caption onto ``image``.

    Args:
        detections: iterable of ``(label, confidence, bbox)`` tuples; ``bbox``
            is converted to corner points with ``bbox2points``.
        image: image array that cv2 draws onto.
        colors: mapping from label to the colour used for box and caption.

    Returns:
        The same ``image`` object that was passed in.
    """
    for label, confidence, bbox in detections:
        left, top, right, bottom = bbox2points(bbox)
        shade = colors[label]
        cv2.rectangle(image, (left, top), (right, bottom), shade, 1)
        caption = "{} [{:.2f}]".format(label, float(confidence))
        # Caption sits 5 px above the box's top edge.
        cv2.putText(image, caption, (left, top - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, shade, 2)
    return image
|
|
|
|
|
|
def decode_detection(detections):
    """Convert ``(label, confidence, bbox)`` tuples into display form.

    The label is passed through ``str``; the confidence is multiplied by 100,
    rounded to two decimals and turned into a string; the bbox is untouched.
    """
    return [
        (str(label), str(round(confidence * 100, 2)), bbox)
        for label, confidence, bbox in detections
    ]
|
|
|
|
|
|
# https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
# Malisiewicz et al.
def non_max_suppression_fast(detections, overlap_thresh):
    """Greedy non-maximum suppression over centre-format boxes.

    Follows the Malisiewicz et al. variant referenced above: candidates are
    visited by their bottom edge (largest ``y2`` first) rather than by
    confidence. A remaining box is discarded when its intersection with the
    kept box, divided by the remaining box's own area, exceeds
    ``overlap_thresh``.

    Args:
        detections: sequence of tuples whose LAST element is the box
            ``(x, y, w, h)`` with ``x``/``y`` the centre. This accepts the
            ``(label, confidence, bbox)`` triples built elsewhere in this
            module as well as longer tuples ending in the bbox.
        overlap_thresh (float): overlap ratio above which a box is suppressed.

    Returns:
        list: the kept detection tuples, in the order they were picked.
        An empty input yields an empty list.
    """
    if len(detections) == 0:
        # np.array([]) is 1-D, so the column slicing below would raise.
        return []

    boxes = []
    for detection in detections:
        x, y, w, h = detection[-1]
        boxes.append((x - w / 2, y - h / 2, x + w / 2, y + h / 2))
    boxes_array = np.array(boxes)

    x1 = boxes_array[:, 0]
    y1 = boxes_array[:, 1]
    x2 = boxes_array[:, 2]
    y2 = boxes_array[:, 3]

    # The +1 follows the inclusive integer-pixel convention of the reference
    # implementation (x1..x2 inclusive spans x2 - x1 + 1 pixels).
    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Ascending by bottom edge; the last remaining index is processed next.
    idxs = np.argsort(y2)

    pick = []
    while len(idxs) > 0:
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(int(i))
        rest = idxs[:last]

        # Intersection rectangle between box i and every remaining box.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)

        overlap = (w * h) / area[rest]

        # Delete the index just picked plus every index whose overlap with
        # it exceeds the threshold.
        idxs = np.delete(idxs, np.concatenate(([last], np.where(overlap > overlap_thresh)[0])))

    return [detections[i] for i in pick]
|
|
|
|
|
|
def remove_negatives(detections, class_names, num):
    """Expand raw detections into one tuple per strictly positive class score.

    For each of the first ``num`` detections and each class in
    ``class_names``, emits ``(class_name, score, (x, y, w, h))`` when the
    score is greater than zero. Output is ordered detection-first, then by
    class order.
    """
    kept = []
    for det in (detections[k] for k in range(num)):
        for class_idx, class_name in enumerate(class_names):
            score = det.prob[class_idx]
            if not score > 0:
                continue
            box = det.bbox
            kept.append((class_name, score, (box.x, box.y, box.w, box.h)))
    return kept
|
|
|
|
|
|
def remove_negatives_faster(detections, class_names, num):
    """Keep only each detection's best class (faster than ``remove_negatives``).

    For each of the first ``num`` detections whose ``best_class_idx`` is not
    -1, emits a single ``(class_name, score, (x, y, w, h))`` tuple for that
    class. Handy with very large class sets such as yolo9000.
    """
    predictions = []
    for j in range(num):
        det = detections[j]
        best = det.best_class_idx
        if best != -1:
            label = class_names[best]
            box = det.bbox
            predictions.append((label, det.prob[best], (box.x, box.y, box.w, box.h)))
    return predictions
|
|
|
|
|
|
def detect_image(network, class_names, image, thresh=.5, hier_thresh=.5, nms=.45):
    """Run the network on a darknet IMAGE and return its detections.

    Args:
        network: handle returned by ``load_network``.
        class_names (list[str]): class names, indexed like the network scores.
        image (IMAGE): darknet image to run inference on.
        thresh (float): threshold passed to ``get_network_boxes``.
        hier_thresh (float): hierarchical threshold passed to ``get_network_boxes``.
        nms (float): threshold passed to ``do_nms_sort``; a falsy value skips NMS.

    Returns:
        list: every ``(label, confidence, bbox)`` with a positive class score
        (see ``remove_negatives``), formatted by ``decode_detection`` (so
        confidence is a percentage *string*), sorted by ascending numeric
        confidence.
    """
    pnum = ct.pointer(ct.c_int(0))
    predict_image(network, image)
    detections = get_network_boxes(network, image.w, image.h,
                                   thresh, hier_thresh, None, 0, pnum, 0)
    num = pnum[0]
    try:
        if nms:
            do_nms_sort(detections, num, len(class_names), nms)
        # remove_negatives copies values into Python tuples, so the C buffer
        # can be released right afterwards.
        predictions = remove_negatives(detections, class_names, num)
    finally:
        # Free the C-side detections even if post-processing raises.
        free_detections(detections, num)
    predictions = decode_detection(predictions)
    # decode_detection stores confidence as a string; sort numerically so
    # that e.g. "100.0" is not ordered before "9.5".
    return sorted(predictions, key=lambda p: float(p[1]))
|
|
|
|
|
|
# Load the darknet shared library that lives next to this file.
if os.name == "posix":
    # abspath guards against __file__ being a bare file name: dirname() would
    # then return "" and the path would resolve to "/libdarknet.so".
    cwd = os.path.dirname(os.path.abspath(__file__))
    lib = ct.CDLL(os.path.join(cwd, "libdarknet.so"), ct.RTLD_GLOBAL)
elif os.name == "nt":
    cwd = os.path.dirname(os.path.abspath(__file__))
    # Prepend this directory to PATH so the loader can find darknet.dll
    # (and, presumably, DLLs it depends on) next to this file.
    os.environ["PATH"] = os.path.pathsep.join((cwd, os.environ.get("PATH", "")))
    lib = ct.CDLL("darknet.dll", winmode=0, mode=ct.RTLD_GLOBAL)
else:
    lib = None  # Intellisense
    print("Unsupported OS")
    # Raise SystemExit directly: the exit() builtin is injected by the site
    # module and is missing under `python -S` or in frozen executables.
    raise SystemExit
|
|
|
|
# ---------------------------------------------------------------------------
# ctypes prototypes for the darknet C API.
# argtypes/restype must match the C declarations: without them ctypes treats
# every argument as a C int-sized value and every result as a C int, which
# truncates 64-bit pointers. CDLL caches attribute lookups, so e.g.
# `predict` and `network_predict` below are the same function object.
# ---------------------------------------------------------------------------

lib.network_width.argtypes = (ct.c_void_p,)
lib.network_width.restype = ct.c_int
lib.network_height.argtypes = (ct.c_void_p,)
lib.network_height.restype = ct.c_int

copy_image_from_bytes = lib.copy_image_from_bytes
copy_image_from_bytes.argtypes = (IMAGE, ct.c_char_p)

predict = lib.network_predict_ptr
predict.argtypes = (ct.c_void_p, FloatPtr)
predict.restype = FloatPtr

set_gpu = lib.cuda_set_device
init_cpu = lib.init_cpu

make_image = lib.make_image
make_image.argtypes = (ct.c_int, ct.c_int, ct.c_int)
make_image.restype = IMAGE

get_network_boxes = lib.get_network_boxes
get_network_boxes.argtypes = (ct.c_void_p, ct.c_int, ct.c_int, ct.c_float, ct.c_float,
                              IntPtr, ct.c_int, IntPtr, ct.c_int)
get_network_boxes.restype = DETECTIONPtr

make_network_boxes = lib.make_network_boxes
# C: detection *make_network_boxes(network *net, float thresh, int *num).
# The previous (c_void_p,) prototype left thresh and num uninitialised.
make_network_boxes.argtypes = (ct.c_void_p, ct.c_float, IntPtr)
make_network_boxes.restype = DETECTIONPtr

free_detections = lib.free_detections
free_detections.argtypes = (DETECTIONPtr, ct.c_int)

free_batch_detections = lib.free_batch_detections
free_batch_detections.argtypes = (DETNUMPAIRPtr, ct.c_int)

free_ptrs = lib.free_ptrs
free_ptrs.argtypes = (ct.POINTER(ct.c_void_p), ct.c_int)

network_predict = lib.network_predict_ptr
network_predict.argtypes = (ct.c_void_p, FloatPtr)

reset_rnn = lib.reset_rnn
reset_rnn.argtypes = (ct.c_void_p,)

load_net = lib.load_network
load_net.argtypes = (ct.c_char_p, ct.c_char_p, ct.c_int)
load_net.restype = ct.c_void_p

load_net_custom = lib.load_network_custom
load_net_custom.argtypes = (ct.c_char_p, ct.c_char_p, ct.c_int, ct.c_int)
load_net_custom.restype = ct.c_void_p

free_network_ptr = lib.free_network_ptr
free_network_ptr.argtypes = (ct.c_void_p,)
# The C function returns void; a c_void_p restype would only surface garbage.
free_network_ptr.restype = None

do_nms_obj = lib.do_nms_obj
do_nms_obj.argtypes = (DETECTIONPtr, ct.c_int, ct.c_int, ct.c_float)

do_nms_sort = lib.do_nms_sort
do_nms_sort.argtypes = (DETECTIONPtr, ct.c_int, ct.c_int, ct.c_float)

free_image = lib.free_image
free_image.argtypes = (IMAGE,)

letterbox_image = lib.letterbox_image
letterbox_image.argtypes = (IMAGE, ct.c_int, ct.c_int)
letterbox_image.restype = IMAGE

load_meta = lib.get_metadata
lib.get_metadata.argtypes = (ct.c_char_p,)
lib.get_metadata.restype = METADATA

load_image = lib.load_image_color
load_image.argtypes = (ct.c_char_p, ct.c_int, ct.c_int)
load_image.restype = IMAGE

rgbgr_image = lib.rgbgr_image
rgbgr_image.argtypes = (IMAGE,)

predict_image = lib.network_predict_image
predict_image.argtypes = (ct.c_void_p, IMAGE)
predict_image.restype = FloatPtr

predict_image_letterbox = lib.network_predict_image_letterbox
predict_image_letterbox.argtypes = (ct.c_void_p, IMAGE)
predict_image_letterbox.restype = FloatPtr

network_predict_batch = lib.network_predict_batch
network_predict_batch.argtypes = (ct.c_void_p, IMAGE, ct.c_int, ct.c_int, ct.c_int,
                                  ct.c_float, ct.c_float, IntPtr, ct.c_int, ct.c_int)
network_predict_batch.restype = DETNUMPAIRPtr
|