aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/object_detection/yolo.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/object_detection/yolo.py')
-rw-r--r--python/pyarmnn/examples/object_detection/yolo.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/python/pyarmnn/examples/object_detection/yolo.py b/python/pyarmnn/examples/object_detection/yolo.py
index 1748d158a2..e76ed7b2f4 100644
--- a/python/pyarmnn/examples/object_detection/yolo.py
+++ b/python/pyarmnn/examples/object_detection/yolo.py
@@ -80,19 +80,19 @@ def yolo_processing(output: np.ndarray, confidence_threshold=0.40, iou_threshold
return nms_det
-def yolo_resize_factor(video: cv2.VideoCapture, input_binding_info: tuple):
+def yolo_resize_factor(video: cv2.VideoCapture, input_data_shape: tuple):
"""
Gets a multiplier to scale the bounding box positions to
their correct position in the frame.
Args:
video: Video capture object, contains information about data source.
- input_binding_info: Contains shape of model input layer.
+ input_data_shape: Contains shape of model input layer.
Returns:
Resizing factor to scale box coordinates to output frame size.
"""
frame_height = video.get(cv2.CAP_PROP_FRAME_HEIGHT)
frame_width = video.get(cv2.CAP_PROP_FRAME_WIDTH)
- model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
+ _, model_height, model_width, _= input_data_shape
return max(frame_height, frame_width) / max(model_height, model_width)