aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorJakub Sujak <jakub.sujak@arm.com>2020-06-17 15:35:03 +0100
committerJakub Sujak <jakub.sujak@arm.com>2020-08-24 09:18:14 +0000
commit433a59567ccf7fd6fdbbd1227eac3778876e8bd9 (patch)
tree3321c3e5fed27955a7259210968c325f8886c1f2 /python
parentfc5d5c21040556498bd7330f560811e42fc1a11b (diff)
downloadarmnn-433a59567ccf7fd6fdbbd1227eac3778876e8bd9.tar.gz
MLECO-955: Added python object detection example for PyArmNN
Change-Id: I1344c027f4cc70520b7846b34dfbc2abf399d10a Signed-off-by: Jakub Sujak <jakub.sujak@arm.com>
Diffstat (limited to 'python')
-rw-r--r--python/pyarmnn/examples/object_detection/README.md186
-rw-r--r--python/pyarmnn/examples/object_detection/requirements.txt3
-rw-r--r--python/pyarmnn/examples/object_detection/run_video_file.py106
-rw-r--r--python/pyarmnn/examples/object_detection/run_video_stream.py102
-rw-r--r--python/pyarmnn/examples/object_detection/ssd.py53
-rw-r--r--python/pyarmnn/examples/object_detection/ssd_labels.txt91
-rw-r--r--python/pyarmnn/examples/object_detection/utils.py231
-rw-r--r--python/pyarmnn/examples/object_detection/yolo.py98
-rw-r--r--python/pyarmnn/examples/object_detection/yolo_labels.txt80
9 files changed, 950 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/object_detection/README.md b/python/pyarmnn/examples/object_detection/README.md
new file mode 100644
index 0000000000..5d401630ad
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/README.md
@@ -0,0 +1,186 @@
+# PyArmNN Object Detection Sample Application
+
+## Introduction
+This sample application guides the user and shows how to perform object detection using PyArmNN API. We assume the user has already built PyArmNN by following the instructions of the README in the main PyArmNN directory.
+
+We provide example scripts for performing object detection from video file and video stream with `run_video_file.py` and `run_video_stream.py`.
+
+The application takes a model and video file or camera feed as input, runs inference on each frame, and draws bounding boxes around detected objects, with the corresponding labels and confidence scores overlaid.
+
+A similar implementation of this object detection application is also provided in C++ in the examples for ArmNN.
+
+## Prerequisites
+
+##### PyArmNN
+
+Before proceeding to the next steps, make sure that you have successfully installed the newest version of PyArmNN on your system by following the instructions in the README of the PyArmNN root directory.
+
+You can verify that PyArmNN library is installed and check PyArmNN version using:
+```bash
+$ pip show pyarmnn
+```
+
+You can also verify it by running the following and getting output similar to below:
+```bash
+$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
+'22.0.0'
+```
+
+##### Dependencies
+
+Install the following libraries on your system:
+```bash
+$ sudo apt-get install python3-opencv libqtgui4 libqt4-test
+```
+
+Create a virtual environment:
+```bash
+$ python3.7 -m venv devenv --system-site-packages
+$ source devenv/bin/activate
+```
+
+Install the dependencies:
+```bash
+$ pip install -r requirements.txt
+```
+
+---
+
+# Performing Object Detection
+
+## Object Detection from Video File
+The `run_video_file.py` example takes a video file as input, runs inference on each frame, and produces frames with bounding boxes drawn around detected objects. The processed frames are written to video file.
+
+The user can specify these arguments at command line:
+
+* `--video_file_path` - <b>Required:</b> Path to the video file to run object detection on
+* `--model_file_path` - <b>Required:</b> Path to <b>.tflite, .pb</b> or <b>.onnx</b> object detection model
+* `--model_name` - <b>Required:</b> The name of the model being used. Assembles the workflow for the input model. The examples support the model names:
+ * `ssd_mobilenet_v1`
+ * `yolo_v3_tiny`
+* `--label_path` - Path to labels file for the specified model file. Labels are provided for above model names
+* `--output_video_file_path` - Path to the output video file with detections added in
+* `--preferred_backends` - You can specify one or more backend in order of preference. Accepted backends include `CpuAcc, GpuAcc, CpuRef`. Arm NN will decide which layers of the network are supported by the backend, falling back to the next if a layer is unsupported. Defaults to `['CpuAcc', 'CpuRef']`
+
+
+Run the sample script:
+```bash
+$ python run_video_file.py --video_file_path <video_file_path> --model_file_path <model_file_path> --model_name <model_name>
+```
+
+## Object Detection from Video Stream
+The `run_video_stream.py` example captures frames from a video stream of a device, runs inference on each frame, and produces frames with bounding boxes drawn around detected objects. A window is displayed and refreshed with the latest processed frame.
+
+The user can specify these arguments at command line:
+
+* `--video_source` - Device index to access video stream. Defaults to primary device camera at index 0
+* `--model_file_path` - <b>Required:</b> Path to <b>.tflite, .pb</b> or <b>.onnx</b> object detection model
+* `--model_name` - <b>Required:</b> The name of the model being used. Assembles the workflow for the input model. The examples support the model names:
+ * `ssd_mobilenet_v1`
+ * `yolo_v3_tiny`
+* `--label_path` - Path to labels file for the specified model file. Labels are provided for above model names
+* `--preferred_backends` - You can specify one or more backend in order of preference. Accepted backends include `CpuAcc, GpuAcc, CpuRef`. Arm NN will decide which layers of the network are supported by the backend, falling back to the next if a layer is unsupported. Defaults to `['CpuAcc', 'CpuRef']`
+
+
+Run the sample script:
+```bash
+$ python run_video_stream.py --model_file_path <model_file_path> --model_name <model_name>
+```
+
+## Implementing Your Own Network
+The examples provide support for `ssd_mobilenet_v1` and `yolo_v3_tiny` models. However, the user is able to add their own network to the object detection scripts by following the steps:
+
+1. Create a new file for your network, for example `network.py`, to contain functions to process the output of the model
+2. In that file, the user will need to write a function that decodes the output vectors obtained from running inference on their network and return the bounding box positions of detected objects plus their class index and confidence. Additionally, include a function that returns a resize factor that will scale the obtained bounding boxes to their correct positions in the original frame
+3. Import the functions into the main file and, such as with the provided networks, add a conditional statement to the `get_model_processing()` function with the new model name and functions
+4. The labels associated with the model can then either be included inside the conditional statement or passed in with `--label_path` argument when executing the main script
+
+---
+
+# Application Overview
+This section provides a walkthrough of the application, explaining in detail the steps:
+1. Initialisation
+ 1.1. Reading from Video Source
+ 1.2. Preparing Labels and Model Specific Functions
+2. Creating a Network
+ 2.1. Creating Parser and Importing Graph
+ 2.2. Optimizing Graph for Compute Device
+ 2.3. Creating Input and Output Binding Information
+3. Preparing the Workload Tensors
+ 3.1. Preprocessing the Captured Frame
+ 3.2. Making Input and Output Tensors
+4. Executing Inference
+5. Postprocessing
+ 5.1. Decoding and Processing Inference Output
+ 5.2. Drawing Bounding Boxes
+
+
+### Initialisation
+
+##### Reading from Video Source
+After parsing user arguments, the chosen video file or stream is loaded into an OpenCV `cv2.VideoCapture()` object. We use this object to capture frames from the source using the `read()` function.
+
+The `VideoCapture` object also tells us information about the source, such as the framerate and resolution of the input video. Using this information, we create a `cv2.VideoWriter()` object which will be used at the end of every loop to write the processed frame to an output video file of the same format as the input.
+
+##### Preparing Labels and Model Specific Functions
+In order to interpret the result of running inference on the loaded network, it is required to load the labels associated with the model. In the provided example code, the `dict_labels()` function creates a dictionary that is keyed on the classification index at the output node of the model, with values of the dictionary corresponding to a label and a randomly generated RGB color. This ensures that each class has a unique color which will prove helpful when plotting the bounding boxes of various detected objects in a frame.
+
+Depending on the model being used, the user-specified model name accesses and returns functions to decode and process the inference output, along with a resize factor used when plotting bounding boxes to ensure they are scaled to their correct position in the original frame.
+
+
+### Creating a Network
+
+##### Creating Parser and Importing Graph
+The first step with PyArmNN is to import a graph from file by using the appropriate parser.
+
+The Arm NN SDK provides parsers for reading graphs from a variety of model formats. In our application we specifically focus on `.tflite, .pb, .onnx` models.
+
+Based on the extension of the provided model file, the corresponding parser is created and the network file loaded with `CreateNetworkFromBinaryFile()` function. The parser will handle the creation of the underlying Arm NN graph.
+
+##### Optimizing Graph for Compute Device
+Arm NN supports optimized execution on multiple CPU and GPU devices. Prior to executing a graph, we must select the appropriate device context. We do this by creating a runtime context with default options with `IRuntime()`.
+
+We can optimize the imported graph by specifying a list of backends in order of preference and implement backend-specific optimizations. The backends are identified by a string unique to the backend, for example `CpuAcc, GpuAcc, CpuRef`.
+
+Internally and transparently, Arm NN splits the graph into subgraph based on backends, it calls a optimize subgraphs function on each of them and, if possible, substitutes the corresponding subgraph in the original graph with its optimized version.
+
+Using the `Optimize()` function we optimize the graph for inference and load the optimized network onto the compute device with `LoadNetwork()`. This function creates the backend-specific workloads for the layers and a backend specific workload factory which is called to create the workloads.
+
+##### Creating Input and Output Binding Information
+Parsers can also be used to extract the input information for the network. By calling `GetSubgraphInputTensorNames` we extract all the input names and, with `GetNetworkInputBindingInfo`, bind the input points of the graph.
+
+The input binding information contains all the essential information about the input. It is a tuple consisting of integer identifiers for bindable layers (inputs, outputs) and the tensor info (data type, quantization information, number of dimensions, total number of elements).
+
+Similarly, we can get the output binding information for an output layer by using the parser to retrieve output tensor names and calling `GetNetworkOutputBindingInfo()`.
+
+
+### Preparing the Workload Tensors
+
+##### Preprocessing the Captured Frame
+Each frame captured from source is read as an `ndarray` in BGR format and therefore has to be preprocessed before being passed into the network.
+
+This preprocessing step consists of swapping channels (BGR to RGB in this example), resizing the frame to the required resolution, expanding dimensions of the array and doing data type conversion to match the model input layer. This information about the input tensor can be readily obtained from reading the `input_binding_info`. For example, SSD MobileNet V1 takes for input a tensor with shape `[1, 300, 300, 3]` and data type `uint8`.
+
+##### Making Input and Output Tensors
+To produce the workload tensors, calling the functions `make_input_tensors()` and `make_output_tensors()` will return the input and output tensors respectively.
+
+
+### Executing Inference
+After making the workload tensors, a compute device performs inference for the loaded network using the `EnqueueWorkload()` function of the runtime context. By calling the `workload_tensors_to_ndarray()` function, we obtain the results from inference as a list of `ndarrays`.
+
+
+### Postprocessing
+
+##### Decoding and Processing Inference Output
+The output from inference must be decoded to obtain information about detected objects in the frame. In the examples there are implementations for two networks but you may also implement your own network decoding solution here. Please refer to <i>Implementing Your Own Network</i> section of this document to learn how to do this.
+
+For SSD MobileNet V1 models, we decode the results to obtain the bounding box positions, classification index, confidence and number of detections in the input frame.
+
+For YOLO V3 Tiny models, we decode the output and perform non-maximum suppression to filter out any weak detections below a confidence threshold and any redudant bounding boxes above an intersection-over-union threshold.
+
+It is encouraged to experiment with threshold values for confidence and intersection-over-union (IoU) to achieve the best visual results.
+
+The detection results are always returned as a list in the form `[class index, [box positions], confidence score]`, with the box positions list containing bounding box coordinates in the form `[x_min, y_min, x_max, y_max]`.
+
+##### Drawing Bounding Boxes
+With the obtained results and using `draw_bounding_boxes()`, we are able to draw bounding boxes around detected objects and add the associated label and confidence score. The labels dictionary created earlier uses the class index of the detected object as a key to return the associated label and color for that class. The resize factor defined at the beginning scales the bounding box coordinates to their correct positions in the original frame. The processed frames are written to file or displayed in a separate window.
diff --git a/python/pyarmnn/examples/object_detection/requirements.txt b/python/pyarmnn/examples/object_detection/requirements.txt
new file mode 100644
index 0000000000..7cc6379eb9
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/requirements.txt
@@ -0,0 +1,3 @@
+argparse>=1.4.0
+numpy>=1.19.0
+tqdm>=4.47.0 \ No newline at end of file
diff --git a/python/pyarmnn/examples/object_detection/run_video_file.py b/python/pyarmnn/examples/object_detection/run_video_file.py
new file mode 100644
index 0000000000..4f06eb184d
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/run_video_file.py
@@ -0,0 +1,106 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+Object detection demo that takes a video file, runs inference on each frame producing
+bounding boxes and labels around detected objects, and saves the processed video.
+"""
+
+import os
+import cv2
+import pyarmnn as ann
+from tqdm import tqdm
+from argparse import ArgumentParser
+
+from ssd import ssd_processing, ssd_resize_factor
+from yolo import yolo_processing, yolo_resize_factor
+from utils import create_video_writer, create_network, dict_labels, preprocess, execute_network, draw_bounding_boxes
+
+
+parser = ArgumentParser()
+parser.add_argument('--video_file_path', required=True, type=str,
+ help='Path to the video file to run object detection on')
+parser.add_argument('--model_file_path', required=True, type=str,
+ help='Path to the Object Detection model to use')
+parser.add_argument('--model_name', required=True, type=str,
+ help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
+parser.add_argument('--label_path', type=str,
+ help='Path to the labelset for the provided model file')
+parser.add_argument('--output_video_file_path', type=str,
+ help='Path to the output video file with detections added in')
+parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
+ help='Takes the preferred backends in preference order, separated by whitespace, '
+ 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
+ 'Defaults to [CpuAcc, CpuRef]')
+args = parser.parse_args()
+
+
+def init_video(video_path: str, output_path: str):
+ """
+ Creates a video capture object from a video file.
+
+ Args:
+ video_path: User-specified video file path.
+ output_path: Optional path to save the processed video.
+
+ Returns:
+ Video capture object to capture frames, video writer object to write processed
+ frames to file, plus total frame count of video source to iterate through.
+ """
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f'Video file not found for: {video_path}')
+ video = cv2.VideoCapture(video_path)
+ if not video.isOpened:
+ raise RuntimeError(f'Failed to open video capture from file: {video_path}')
+
+ video_writer = create_video_writer(video, video_path, output_path)
+ iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
+ return video, video_writer, iter_frame_count
+
+
+def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
+ """
+ Gets model-specific information such as model labels and decoding and processing functions.
+ The user can include their own network and functions by adding another statement.
+
+ Args:
+ model_name: Name of type of supported model.
+ video: Video capture object, contains information about data source.
+ input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
+
+ Returns:
+ Model labels, decoding and processing functions.
+ """
+ if model_name == 'ssd_mobilenet_v1':
+ labels = os.path.join('ssd_labels.txt')
+ return labels, ssd_processing, ssd_resize_factor(video)
+ elif model_name == 'yolo_v3_tiny':
+ labels = os.path.join('yolo_labels.txt')
+ return labels, yolo_processing, yolo_resize_factor(video, input_binding_info)
+ else:
+ raise ValueError(f'{model_name} is not a valid model name')
+
+
+def main(args):
+ video, video_writer, frame_count = init_video(args.video_file_path, args.output_video_file_path)
+ net_id, runtime, input_binding_info, output_binding_info = create_network(args.model_file_path,
+ args.preferred_backends)
+ output_tensors = ann.make_output_tensors(output_binding_info)
+ labels, process_output, resize_factor = get_model_processing(args.model_name, video, input_binding_info)
+ labels = dict_labels(labels if args.label_path is None else args.label_path)
+
+ for _ in tqdm(frame_count, desc='Processing frames'):
+ frame_present, frame = video.read()
+ if not frame_present:
+ continue
+ input_tensors = preprocess(frame, input_binding_info)
+ inference_output = execute_network(input_tensors, output_tensors, runtime, net_id)
+ detections = process_output(inference_output)
+ draw_bounding_boxes(frame, detections, resize_factor, labels)
+ video_writer.write(frame)
+ print('Finished processing frames')
+ video.release(), video_writer.release()
+
+
+if __name__ == '__main__':
+ main(args)
diff --git a/python/pyarmnn/examples/object_detection/run_video_stream.py b/python/pyarmnn/examples/object_detection/run_video_stream.py
new file mode 100644
index 0000000000..94dc6c8b13
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/run_video_stream.py
@@ -0,0 +1,102 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+Object detection demo that takes a video stream from a device, runs inference
+on each frame producing bounding boxes and labels around detected objects,
+and displays a window with the latest processed frame.
+"""
+
+import os
+import cv2
+import pyarmnn as ann
+from tqdm import tqdm
+from argparse import ArgumentParser
+
+from ssd import ssd_processing, ssd_resize_factor
+from yolo import yolo_processing, yolo_resize_factor
+from utils import create_network, dict_labels, preprocess, execute_network, draw_bounding_boxes
+
+
+parser = ArgumentParser()
+parser.add_argument('--video_source', type=int, default=0,
+ help='Device index to access video stream. Defaults to primary device camera at index 0')
+parser.add_argument('--model_file_path', required=True, type=str,
+ help='Path to the Object Detection model to use')
+parser.add_argument('--model_name', required=True, type=str,
+ help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
+parser.add_argument('--label_path', type=str,
+ help='Path to the labelset for the provided model file')
+parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
+ help='Takes the preferred backends in preference order, separated by whitespace, '
+ 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
+ 'Defaults to [CpuAcc, CpuRef]')
+args = parser.parse_args()
+
+
+def init_video(video_source: int):
+ """
+ Creates a video capture object from a device.
+
+ Args:
+ video_source: Device index used to read video stream.
+
+ Returns:
+ Video capture object used to capture frames from a video stream.
+ """
+ video = cv2.VideoCapture(video_source)
+ if not video.isOpened:
+ raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
+ print('Processing video stream. Press \'Esc\' key to exit the demo.')
+ return video
+
+
+def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
+ """
+ Gets model-specific information such as model labels and decoding and processing functions.
+ The user can include their own network and functions by adding another statement.
+
+ Args:
+ model_name: Name of type of supported model.
+ video: Video capture object, contains information about data source.
+ input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
+
+ Returns:
+ Model labels, decoding and processing functions.
+ """
+ if model_name == 'ssd_mobilenet_v1':
+ labels = os.path.join('ssd_labels.txt')
+ return labels, ssd_processing, ssd_resize_factor(video)
+ elif model_name == 'yolo_v3_tiny':
+ labels = os.path.join('yolo_labels.txt')
+ return labels, yolo_processing, yolo_resize_factor(video, input_binding_info)
+ else:
+ raise ValueError(f'{model_name} is not a valid model name')
+
+
+def main(args):
+ video = init_video(args.video_source)
+ net_id, runtime, input_binding_info, output_binding_info = create_network(args.model_file_path,
+ args.preferred_backends)
+ output_tensors = ann.make_output_tensors(output_binding_info)
+ labels, process_output, resize_factor = get_model_processing(args.model_name, video, input_binding_info)
+ labels = dict_labels(labels if args.label_path is None else args.label_path)
+
+ while True:
+ frame_present, frame = video.read()
+ frame = cv2.flip(frame, 1) # Horizontally flip the frame
+ if not frame_present:
+ raise RuntimeError('Error reading frame from video stream')
+ input_tensors = preprocess(frame, input_binding_info)
+ inference_output = execute_network(input_tensors, output_tensors, runtime, net_id)
+ detections = process_output(inference_output)
+ draw_bounding_boxes(frame, detections, resize_factor, labels)
+ cv2.imshow('PyArmNN Object Detection Demo', frame)
+ if cv2.waitKey(1) == 27:
+ print('\nExit key activated. Closing video...')
+ break
+ video.release(), cv2.destroyAllWindows()
+
+
+if __name__ == '__main__':
+ main(args)
diff --git a/python/pyarmnn/examples/object_detection/ssd.py b/python/pyarmnn/examples/object_detection/ssd.py
new file mode 100644
index 0000000000..2016c4cbce
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/ssd.py
@@ -0,0 +1,53 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+Contains functions specific to decoding and processing inference results for SSD Mobilenet V1 models.
+"""
+
+import cv2
+import numpy as np
+
+
+def ssd_processing(output: np.ndarray, confidence_threshold=0.60):
+ """
+ Gets class, bounding box positions and confidence from the four outputs of the SSD model.
+
+ Args:
+ output: Vector of outputs from network.
+ confidence_threshold: Selects only strong detections above this value.
+
+ Returns:
+ A list of detected objects in the form [class, [box positions], confidence]
+ """
+ if len(output) != 4:
+ raise RuntimeError('Number of outputs from SSD model does not equal 4')
+
+ position, classification, confidence, num_detections = [index[0] for index in output]
+
+ detections = []
+ for i in range(int(num_detections)):
+ if confidence[i] > confidence_threshold:
+ class_idx = classification[i]
+ box = position[i, :4]
+ # Reorder positions in format [x_min, y_min, x_max, y_max]
+ box[0], box[1], box[2], box[3] = box[1], box[0], box[3], box[2]
+ confidence_value = confidence[i]
+ detections.append((class_idx, box, confidence_value))
+ return detections
+
+
+def ssd_resize_factor(video: cv2.VideoCapture):
+ """
+ Gets a multiplier to scale the bounding box positions to
+ their correct position in the frame.
+
+ Args:
+ video: Video capture object, contains information about data source.
+
+ Returns:
+ Resizing factor to scale box coordinates to output frame size.
+ """
+ frame_height = video.get(cv2.CAP_PROP_FRAME_HEIGHT)
+ frame_width = video.get(cv2.CAP_PROP_FRAME_WIDTH)
+ return max(frame_height, frame_width)
diff --git a/python/pyarmnn/examples/object_detection/ssd_labels.txt b/python/pyarmnn/examples/object_detection/ssd_labels.txt
new file mode 100644
index 0000000000..5378c6cdad
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/ssd_labels.txt
@@ -0,0 +1,91 @@
+person
+bicycle
+car
+motorcycle
+airplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+street sign
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+hat
+backpack
+umbrella
+shoe
+eye glasses
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+plate
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+couch
+potted plant
+bed
+mirror
+dining table
+window
+desk
+toilet
+door
+tv
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+blender
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
+hair brush \ No newline at end of file
diff --git a/python/pyarmnn/examples/object_detection/utils.py b/python/pyarmnn/examples/object_detection/utils.py
new file mode 100644
index 0000000000..1235bf4fa6
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/utils.py
@@ -0,0 +1,231 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+This file contains shared functions used in the object detection scripts for
+preprocessing data, preparing the network and postprocessing.
+"""
+
+import os
+import cv2
+import numpy as np
+import pyarmnn as ann
+
+
+def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
+ """
+ Creates a video writer object to write processed frames to file.
+
+ Args:
+ video: Video capture object, contains information about data source.
+ video_path: User-specified video file path.
+ output_path: Optional path to save the processed video.
+
+ Returns:
+ Video writer object.
+ """
+ _, ext = os.path.splitext(video_path)
+
+ if output_path is not None:
+ assert os.path.isdir(output_path)
+
+ i, filename = 0, os.path.join(output_path if output_path is not None else str(), f'object_detection_demo{ext}')
+ while os.path.exists(filename):
+ i += 1
+ filename = os.path.join(output_path if output_path is not None else str(), f'object_detection_demo({i}){ext}')
+
+ video_writer = cv2.VideoWriter(filename=filename,
+ fourcc=cv2.VideoWriter_fourcc(*'mp4v'),
+ fps=int(video.get(cv2.CAP_PROP_FPS)),
+ frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
+ int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
+ return video_writer
+
+
+def create_network(model_file: str, backends: list):
+ """
+ Creates a network based on the model file and a list of backends.
+
+ Args:
+ model_file: User-specified model file.
+ backends: List of backends to optimize network.
+
+ Returns:
+ net_id: Unique ID of the network to run.
+ runtime: Runtime context for executing inference.
+ input_binding_info: Contains essential information about the model input.
+ output_binding_info: Used to map output tensor and its memory.
+ """
+ if not os.path.exists(model_file):
+ raise FileNotFoundError(f'Model file not found for: {model_file}')
+
+ # Determine which parser to create based on model file extension
+ parser = None
+ _, ext = os.path.splitext(model_file)
+ if ext == '.tflite':
+ parser = ann.ITfLiteParser()
+ elif ext == '.pb':
+ parser = ann.ITfParser()
+ elif ext == '.onnx':
+ parser = ann.IOnnxParser()
+ assert (parser is not None)
+ network = parser.CreateNetworkFromBinaryFile(model_file)
+
+ # Specify backends to optimize network
+ preferred_backends = []
+ for b in backends:
+ preferred_backends.append(ann.BackendId(b))
+
+ # Select appropriate device context and optimize the network for that device
+ options = ann.CreationOptions()
+ runtime = ann.IRuntime(options)
+ opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
+ ann.OptimizerOptions())
+ print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n'
+ f'Optimization warnings: {messages}')
+
+ # Load the optimized network onto the Runtime device
+ net_id, _ = runtime.LoadNetwork(opt_network)
+
+ # Get input and output binding information
+ graph_id = parser.GetSubgraphCount() - 1
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+ output_binding_info = []
+ for output_name in output_names:
+ outBindInfo = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
+ output_binding_info.append(outBindInfo)
+ return net_id, runtime, input_binding_info, output_binding_info
+
+
+def dict_labels(labels_file: str):
+ """
+ Creates a labels dictionary from the input labels file.
+
+ Args:
+ labels_file: Default or user-specified file containing the model output labels.
+
+ Returns:
+ A dictionary keyed on the classification index with values corresponding to
+ labels and randomly generated RGB colors.
+ """
+ labels_dict = {}
+ with open(labels_file, 'r') as labels:
+ for index, line in enumerate(labels, 0):
+ labels_dict[index] = line.strip('\n'), tuple(np.random.random(size=3) * 255)
+ return labels_dict
+
+
+def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
+ """
+ Resizes frame while maintaining aspect ratio, padding any empty space.
+
+ Args:
+ frame: Captured frame.
+ input_binding_info: Contains shape of model input layer.
+
+ Returns:
+ Frame resized to the size of model input layer.
+ """
+ aspect_ratio = frame.shape[1] / frame.shape[0]
+ model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
+
+ if aspect_ratio >= 1.0:
+ new_height, new_width = int(model_width / aspect_ratio), model_width
+ b_padding, r_padding = model_height - new_height, 0
+ else:
+ new_height, new_width = model_height, int(model_height * aspect_ratio)
+ b_padding, r_padding = 0, model_width - new_width
+
+ # Resize and pad any empty space
+ frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
+ frame = cv2.copyMakeBorder(frame, top=0, bottom=b_padding, left=0, right=r_padding,
+ borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
+ return frame
+
+
+def preprocess(frame: np.ndarray, input_binding_info: tuple):
+ """
+ Takes a frame, resizes, swaps channels and converts data type to match
+ model input layer. The converted frame is wrapped in a const tensor
+ and bound to the input tensor.
+
+ Args:
+ frame: Captured frame from video.
+ input_binding_info: Contains shape and data type of model input layer.
+
+ Returns:
+ Input tensor.
+ """
+ # Swap channels and resize frame to model resolution
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ resized_frame = resize_with_aspect_ratio(frame, input_binding_info)
+
+ # Expand dimensions and convert data type to match model input
+ data_type = np.float32 if input_binding_info[1].GetDataType() == ann.DataType_Float32 else np.uint8
+ resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)
+ assert resized_frame.shape == tuple(input_binding_info[1].GetShape())
+
+ input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
+ return input_tensors
+
+
+def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> np.ndarray:
+ """
+ Executes inference for the loaded network.
+
+ Args:
+ input_tensors: The input frame tensor.
+ output_tensors: The output tensor from output node.
+ runtime: Runtime context for executing inference.
+ net_id: Unique ID of the network to run.
+
+ Returns:
+ Inference results as a list of ndarrays.
+ """
+ runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+ output = ann.workload_tensors_to_ndarray(output_tensors)
+ return output
+
+
+def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
+ """
+ Draws bounding boxes around detected objects and adds a label and confidence score.
+
+ Args:
+ frame: The original captured frame from video source.
+ detections: A list of detected objects in the form [class, [box positions], confidence].
+ resize_factor: Resizing factor to scale box coordinates to output frame size.
+ labels: Dictionary of labels and colors keyed on the classification index.
+ """
+ for detection in detections:
+ class_idx, box, confidence = [d for d in detection]
+ label, color = labels[class_idx][0].capitalize(), labels[class_idx][1]
+
+ # Obtain frame size and resized bounding box positions
+ frame_height, frame_width = frame.shape[:2]
+ x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]
+
+ # Ensure box stays within the frame
+ x_min, y_min = max(0, x_min), max(0, y_min)
+ x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)
+
+ # Draw bounding box around detected object
+ cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)
+
+ # Create label for detected object class
+ label = f'{label} {confidence * 100:.1f}%'
+ label_color = (0, 0, 0) if sum(color)>200 else (255, 255, 255)
+
+ # Make sure label always stays on-screen
+ x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]
+
+ lbl_box_xy_min = (x_min, y_min if y_min<25 else y_min - y_text)
+ lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min<25 else y_min)
+ lbl_text_pos = (x_min + 5, y_min + 16 if y_min<25 else y_min - 5)
+
+ # Add label and confidence value
+ cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
+ cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
+ label_color, 1, cv2.LINE_AA)
diff --git a/python/pyarmnn/examples/object_detection/yolo.py b/python/pyarmnn/examples/object_detection/yolo.py
new file mode 100644
index 0000000000..1748d158a2
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/yolo.py
@@ -0,0 +1,98 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+Contains functions specific to decoding and processing inference results for YOLO V3 Tiny models.
+"""
+
+import cv2
+import numpy as np
+
+
+def iou(box1: list, box2: list):
+ """
+ Calculates the intersection-over-union (IoU) value for two bounding boxes.
+
+ Args:
+ box1: Array of positions for first bounding box
+ in the form [x_min, y_min, x_max, y_max].
+ box2: Array of positions for second bounding box.
+
+ Returns:
+ Calculated intersection-over-union (IoU) value for two bounding boxes.
+ """
+ area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+ area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+
+ if area_box1 <= 0 or area_box2 <= 0:
+ iou_value = 0
+ else:
+ y_min_intersection = max(box1[1], box2[1])
+ x_min_intersection = max(box1[0], box2[0])
+ y_max_intersection = min(box1[3], box2[3])
+ x_max_intersection = min(box1[2], box2[2])
+
+ area_intersection = max(0, y_max_intersection - y_min_intersection) *\
+ max(0, x_max_intersection - x_min_intersection)
+ area_union = area_box1 + area_box2 - area_intersection
+
+ try:
+ iou_value = area_intersection / area_union
+ except ZeroDivisionError:
+ iou_value = 0
+
+ return iou_value
+
+
+def yolo_processing(output: np.ndarray, confidence_threshold=0.40, iou_threshold=0.40):
+ """
+ Performs non-maximum suppression on input detections. Any detections
+ with IOU value greater than given threshold are suppressed.
+
+ Args:
+ output: Vector of outputs from network.
+ confidence_threshold: Selects only strong detections above this value.
+ iou_threshold: Filters out boxes with IOU values above this value.
+
+ Returns:
+ A list of detected objects in the form [class, [box positions], confidence]
+ """
+ if len(output) != 1:
+ raise RuntimeError('Number of outputs from YOLO model does not equal 1')
+
+ # Find the array index of detections with confidence value above threshold
+ confidence_det = output[0][:, :, 4][0]
+ detections = list(np.where(confidence_det > confidence_threshold)[0])
+ all_det, nms_det = [], []
+
+ # Create list of all detections above confidence threshold
+ for d in detections:
+ box_positions = list(output[0][:, d, :4][0])
+ confidence_score = output[0][:, d, 4][0]
+ class_idx = np.argmax(output[0][:, d, 5:])
+ all_det.append((class_idx, box_positions, confidence_score))
+
+ # Suppress detections with IOU value above threshold
+ while all_det:
+ element = int(np.argmax([all_det[i][2] for i in range(len(all_det))]))
+ nms_det.append(all_det.pop(element))
+ all_det = [*filter(lambda x: (iou(x[1], nms_det[-1][1]) <= iou_threshold), [det for det in all_det])]
+ return nms_det
+
+
+def yolo_resize_factor(video: cv2.VideoCapture, input_binding_info: tuple):
+ """
+ Gets a multiplier to scale the bounding box positions to
+ their correct position in the frame.
+
+ Args:
+ video: Video capture object, contains information about data source.
+ input_binding_info: Contains shape of model input layer.
+
+ Returns:
+ Resizing factor to scale box coordinates to output frame size.
+ """
+ frame_height = video.get(cv2.CAP_PROP_FRAME_HEIGHT)
+ frame_width = video.get(cv2.CAP_PROP_FRAME_WIDTH)
+ model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
+ return max(frame_height, frame_width) / max(model_height, model_width)
diff --git a/python/pyarmnn/examples/object_detection/yolo_labels.txt b/python/pyarmnn/examples/object_detection/yolo_labels.txt
new file mode 100644
index 0000000000..c5b80f7022
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/yolo_labels.txt
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorcycle
+airplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+couch
+potted plant
+bed
+dining table
+toilet
+tv
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush \ No newline at end of file