aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/image_classification
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/image_classification')
-rw-r--r--python/pyarmnn/examples/image_classification/example_utils.py20
-rw-r--r--python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py91
-rw-r--r--python/pyarmnn/examples/image_classification/requirements.txt2
-rw-r--r--python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py71
4 files changed, 92 insertions, 92 deletions
diff --git a/python/pyarmnn/examples/image_classification/example_utils.py b/python/pyarmnn/examples/image_classification/example_utils.py
index 090ce2f5b3..f0ba91e981 100644
--- a/python/pyarmnn/examples/image_classification/example_utils.py
+++ b/python/pyarmnn/examples/image_classification/example_utils.py
@@ -38,7 +38,8 @@ def run_inference(runtime, net_id, images, labels, input_binding_info, output_bi
runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
# Process output
- out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+ # output tensor has a shape (1, 1001)
+ out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
results = np.argsort(out_tensor)[::-1]
print_top_n(5, results, labels, out_tensor)
@@ -121,7 +122,7 @@ def __create_network(model_file: str, backends: list, parser=None):
return net_id, parser, runtime
-def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
"""Creates a network from a tflite model file.
Args:
@@ -140,7 +141,7 @@ def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']
return net_id, graph_id, parser, runtime
-def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
"""Creates a network from an onnx model file.
Args:
@@ -181,7 +182,7 @@ def preprocess_default(img: Image, width: int, height: int, data_type, scale: fl
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
- scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
+ scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.),
preprocess_fn=preprocess_default):
"""Loads images, resizes and performs any additional preprocessing to run inference.
@@ -218,7 +219,6 @@ def load_labels(label_file: str):
with open(label_file, 'r') as f:
labels = [l.rstrip() for l in f]
return labels
- return None
def print_top_n(N: int, results: list, labels: list, prob: list):
@@ -299,10 +299,10 @@ def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str =
download_url = [download_url]
for dl in download_url:
archive = download_file(dl)
- if dl.lower().endswith(".zip"):
- unzip_file(archive)
+ if dl.lower().endswith(".zip"):
+ unzip_file(archive)
except RuntimeError:
- print("Unable to download file ({}).".format(archive_url))
+ print("Unable to download file ({}).".format(download_url))
if not os.path.exists(labels) or not os.path.exists(model):
raise RuntimeError("Unable to provide model and labels.")
@@ -310,7 +310,7 @@ def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str =
return model, labels
-def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')):
"""Lists files of a certain format in a folder.
Args:
@@ -338,7 +338,7 @@ def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
"""Gets image.
Args:
- image (str): Image filename
+ image_dir (str): Image filename
image_url (str): Image url
Returns:
diff --git a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
index 9e95f76dcc..be28b585ba 100644
--- a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
+++ b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
@@ -44,48 +44,49 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float
return img
-args = eu.parse_command_line()
-
-model_filename = 'mobilenetv2-1.0.onnx'
-labels_filename = 'synset.txt'
-archive_filename = 'mobilenetv2-1.0.zip'
-labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
-model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
-
-# Download resources
-image_filenames = eu.get_images(args.data_dir)
-
-model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
- archive_filename,
- [model_url, labels_url])
-
-# all 3 resources must exist to proceed further
-assert os.path.exists(labels_filename)
-assert os.path.exists(model_filename)
-assert image_filenames
-for im in image_filenames:
- assert (os.path.exists(im))
-
-# Create a network from a model file
-net_id, parser, runtime = eu.create_onnx_network(model_filename)
-
-# Load input information from the model and create input tensors
-input_binding_info = parser.GetNetworkInputBindingInfo("data")
-
-# Load output information from the model and create output tensors
-output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
-output_tensors = ann.make_output_tensors([output_binding_info])
-
-# Load labels
-labels = eu.load_labels(labels_filename)
-
-# Load images and resize to expected size
-images = eu.load_images(image_filenames,
- 224, 224,
- np.float32,
- 255.0,
- [0.485, 0.456, 0.406],
- [0.229, 0.224, 0.225],
- preprocess_onnx)
-
-eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
+if __name__ == "__main__":
+ args = eu.parse_command_line()
+
+ model_filename = 'mobilenetv2-1.0.onnx'
+ labels_filename = 'synset.txt'
+ archive_filename = 'mobilenetv2-1.0.zip'
+ labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
+ model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
+
+ # Download resources
+ image_filenames = eu.get_images(args.data_dir)
+
+ model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename,
+ [model_url, labels_url])
+
+ # all 3 resources must exist to proceed further
+ assert os.path.exists(labels_filename)
+ assert os.path.exists(model_filename)
+ assert image_filenames
+ for im in image_filenames:
+ assert (os.path.exists(im))
+
+ # Create a network from a model file
+ net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+ # Load input information from the model and create input tensors
+ input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+ # Load output information from the model and create output tensors
+ output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+ output_tensors = ann.make_output_tensors([output_binding_info])
+
+ # Load labels
+ labels = eu.load_labels(labels_filename)
+
+ # Load images and resize to expected size
+ images = eu.load_images(image_filenames,
+ 224, 224,
+ np.float32,
+ 255.0,
+ [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225],
+ preprocess_onnx)
+
+ eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
diff --git a/python/pyarmnn/examples/image_classification/requirements.txt b/python/pyarmnn/examples/image_classification/requirements.txt
index f97e85636e..289a2b521a 100644
--- a/python/pyarmnn/examples/image_classification/requirements.txt
+++ b/python/pyarmnn/examples/image_classification/requirements.txt
@@ -1,4 +1,4 @@
requests>=2.23.0
urllib3>=1.25.8
Pillow>=6.1.0
-numpy>=1.18.1
+numpy>=1.19.2
diff --git a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
index 229a9b6778..6b35f63a00 100644
--- a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
+++ b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
@@ -2,54 +2,53 @@
# Copyright © 2020 NXP and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
-import numpy as np
-import pyarmnn as ann
import example_utils as eu
import os
-args = eu.parse_command_line()
+if __name__ == "__main__":
+ args = eu.parse_command_line()
-# names of the files in the archive
-labels_filename = 'labels_mobilenet_quant_v1_224.txt'
-model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
-archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
+ # names of the files in the archive
+ labels_filename = 'labels_mobilenet_quant_v1_224.txt'
+ model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
+ archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
-archive_url = \
- 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
+ archive_url = \
+ 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
-model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
- archive_filename, archive_url)
+ model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename, archive_url)
-image_filenames = eu.get_images(args.data_dir)
+ image_filenames = eu.get_images(args.data_dir)
-# all 3 resources must exist to proceed further
-assert os.path.exists(labels_filename)
-assert os.path.exists(model_filename)
-assert image_filenames
-for im in image_filenames:
- assert(os.path.exists(im))
+ # all 3 resources must exist to proceed further
+ assert os.path.exists(labels_filename)
+ assert os.path.exists(model_filename)
+ assert image_filenames
+ for im in image_filenames:
+ assert(os.path.exists(im))
-# Create a network from the model file
-net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
+ # Create a network from the model file
+ net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
-# Load input information from the model
-# tflite has all the need information in the model unlike other formats
-input_names = parser.GetSubgraphInputTensorNames(graph_id)
-assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
+ # Load input information from the model
+ # tflite has all the need information in the model unlike other formats
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
-input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
-input_width = input_binding_info[1].GetShape()[1]
-input_height = input_binding_info[1].GetShape()[2]
+ input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+ input_width = input_binding_info[1].GetShape()[1]
+ input_height = input_binding_info[1].GetShape()[2]
-# Load output information from the model and create output tensors
-output_names = parser.GetSubgraphOutputTensorNames(graph_id)
-assert len(output_names) == 1 # and only one output tensor
-output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
+ # Load output information from the model and create output tensors
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+ assert len(output_names) == 1 # and only one output tensor
+ output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
-# Load labels file
-labels = eu.load_labels(labels_filename)
+ # Load labels file
+ labels = eu.load_labels(labels_filename)
-# Load images and resize to expected size
-images = eu.load_images(image_filenames, input_width, input_height)
+ # Load images and resize to expected size
+ images = eu.load_images(image_filenames, input_width, input_height)
-eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
+ eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)