aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
diff options
context:
space:
mode:
authorÉanna Ó Catháin <eanna.ocathain@arm.com>2020-11-16 14:12:11 +0000
committerJim Flynn <jim.flynn@arm.com>2020-11-17 12:23:56 +0000
commit145c88f851d12d2cadc2f080d232c1d5963d6e47 (patch)
tree6ae197d74782cd2c7ef8965f4b36acabc65ce453 /python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
parentaa41d5d2f43790938f3a32586626be5ef55b6ca9 (diff)
downloadarmnn-145c88f851d12d2cadc2f080d232c1d5963d6e47.tar.gz
MLECO-1253 Adding ASR sample application using the PyArmNN api
Change-Id: I450b23800ca316a5bfd4608c8559cf4f11271c21 Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Diffstat (limited to 'python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py')
-rw-r--r--python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py91
1 files changed, 46 insertions, 45 deletions
diff --git a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
index 9e95f76dcc..be28b585ba 100644
--- a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
+++ b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
@@ -44,48 +44,49 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float
return img
-args = eu.parse_command_line()
-
-model_filename = 'mobilenetv2-1.0.onnx'
-labels_filename = 'synset.txt'
-archive_filename = 'mobilenetv2-1.0.zip'
-labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
-model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
-
-# Download resources
-image_filenames = eu.get_images(args.data_dir)
-
-model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
- archive_filename,
- [model_url, labels_url])
-
-# all 3 resources must exist to proceed further
-assert os.path.exists(labels_filename)
-assert os.path.exists(model_filename)
-assert image_filenames
-for im in image_filenames:
- assert (os.path.exists(im))
-
-# Create a network from a model file
-net_id, parser, runtime = eu.create_onnx_network(model_filename)
-
-# Load input information from the model and create input tensors
-input_binding_info = parser.GetNetworkInputBindingInfo("data")
-
-# Load output information from the model and create output tensors
-output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
-output_tensors = ann.make_output_tensors([output_binding_info])
-
-# Load labels
-labels = eu.load_labels(labels_filename)
-
-# Load images and resize to expected size
-images = eu.load_images(image_filenames,
- 224, 224,
- np.float32,
- 255.0,
- [0.485, 0.456, 0.406],
- [0.229, 0.224, 0.225],
- preprocess_onnx)
-
-eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
+if __name__ == "__main__":
+ args = eu.parse_command_line()
+
+ model_filename = 'mobilenetv2-1.0.onnx'
+ labels_filename = 'synset.txt'
+ archive_filename = 'mobilenetv2-1.0.zip'
+ labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
+ model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
+
+ # Download resources
+ image_filenames = eu.get_images(args.data_dir)
+
+ model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename,
+ [model_url, labels_url])
+
+ # all 3 resources must exist to proceed further
+ assert os.path.exists(labels_filename)
+ assert os.path.exists(model_filename)
+ assert image_filenames
+ for im in image_filenames:
+ assert (os.path.exists(im))
+
+ # Create a network from a model file
+ net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+ # Load input information from the model and create input tensors
+ input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+ # Load output information from the model and create output tensors
+ output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+ output_tensors = ann.make_output_tensors([output_binding_info])
+
+ # Load labels
+ labels = eu.load_labels(labels_filename)
+
+ # Load images and resize to expected size
+ images = eu.load_images(image_filenames,
+ 224, 224,
+ np.float32,
+ 255.0,
+ [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225],
+ preprocess_onnx)
+
+ eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)