aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/onnx_mobilenetv2.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/onnx_mobilenetv2.py')
-rwxr-xr-xpython/pyarmnn/examples/onnx_mobilenetv2.py88
1 files changed, 46 insertions, 42 deletions
diff --git a/python/pyarmnn/examples/onnx_mobilenetv2.py b/python/pyarmnn/examples/onnx_mobilenetv2.py
index 5ba08499cc..05bfd7b415 100755
--- a/python/pyarmnn/examples/onnx_mobilenetv2.py
+++ b/python/pyarmnn/examples/onnx_mobilenetv2.py
@@ -4,6 +4,7 @@
import pyarmnn as ann
import numpy as np
+import os
from PIL import Image
import example_utils as eu
@@ -43,45 +44,48 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float
return img
-if __name__ == "__main__":
- # Download resources
- kitten_filename = eu.download_file('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
- labels_filename = eu.download_file('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
- model_filename = eu.download_file(
- 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx')
-
- # Create a network from a model file
- net_id, parser, runtime = eu.create_onnx_network(model_filename)
-
- # Load input information from the model and create input tensors
- input_binding_info = parser.GetNetworkInputBindingInfo("data")
-
- # Load output information from the model and create output tensors
- output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
- output_tensors = ann.make_output_tensors([output_binding_info])
-
- # Load labels
- labels = eu.load_labels(labels_filename)
-
- # Load images and resize to expected size
- image_names = [kitten_filename]
- images = eu.load_images(image_names,
- 224, 224,
- np.float32,
- 255.0,
- [0.485, 0.456, 0.406],
- [0.229, 0.224, 0.225],
- preprocess_onnx)
-
- for idx, im in enumerate(images):
- # Create input tensors
- input_tensors = ann.make_input_tensors([input_binding_info], [im])
-
- # Run inference
- print("Running inference on '{0}' ...".format(image_names[idx]))
- runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
-
- # Process output
- out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
- results = np.argsort(out_tensor)[::-1]
- eu.print_top_n(5, results, labels, out_tensor)
+args = eu.parse_command_line()
+
+model_filename = 'mobilenetv2-1.0.onnx'
+labels_filename = 'synset.txt'
+archive_filename = 'mobilenetv2-1.0.zip'
+labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
+model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
+
+# Download resources
+image_filenames = eu.get_images(args.data_dir)
+
+model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename,
+ [model_url, labels_url])
+
+# all 3 resources must exist to proceed further
+assert os.path.exists(labels_filename)
+assert os.path.exists(model_filename)
+assert image_filenames
+for im in image_filenames:
+ assert (os.path.exists(im))
+
+# Create a network from a model file
+net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+# Load input information from the model and create input tensors
+input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+# Load output information from the model and create output tensors
+output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+output_tensors = ann.make_output_tensors([output_binding_info])
+
+# Load labels
+labels = eu.load_labels(labels_filename)
+
+# Load images and resize to expected size
+images = eu.load_images(image_filenames,
+ 224, 224,
+ np.float32,
+ 255.0,
+ [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225],
+ preprocess_onnx)
+
+eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)