aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py')
-rw-r--r--python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py91
1 files changed, 46 insertions, 45 deletions
diff --git a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
index 9e95f76dcc..be28b585ba 100644
--- a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
+++ b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
@@ -44,48 +44,49 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float
return img
-args = eu.parse_command_line()
-
-model_filename = 'mobilenetv2-1.0.onnx'
-labels_filename = 'synset.txt'
-archive_filename = 'mobilenetv2-1.0.zip'
-labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
-model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
-
-# Download resources
-image_filenames = eu.get_images(args.data_dir)
-
-model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
- archive_filename,
- [model_url, labels_url])
-
-# all 3 resources must exist to proceed further
-assert os.path.exists(labels_filename)
-assert os.path.exists(model_filename)
-assert image_filenames
-for im in image_filenames:
- assert (os.path.exists(im))
-
-# Create a network from a model file
-net_id, parser, runtime = eu.create_onnx_network(model_filename)
-
-# Load input information from the model and create input tensors
-input_binding_info = parser.GetNetworkInputBindingInfo("data")
-
-# Load output information from the model and create output tensors
-output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
-output_tensors = ann.make_output_tensors([output_binding_info])
-
-# Load labels
-labels = eu.load_labels(labels_filename)
-
-# Load images and resize to expected size
-images = eu.load_images(image_filenames,
- 224, 224,
- np.float32,
- 255.0,
- [0.485, 0.456, 0.406],
- [0.229, 0.224, 0.225],
- preprocess_onnx)
-
-eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
+if __name__ == "__main__":
+ args = eu.parse_command_line()
+
+ model_filename = 'mobilenetv2-1.0.onnx'
+ labels_filename = 'synset.txt'
+ archive_filename = 'mobilenetv2-1.0.zip'
+ labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
+ model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
+
+ # Download resources
+ image_filenames = eu.get_images(args.data_dir)
+
+ model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename,
+ [model_url, labels_url])
+
+ # all 3 resources must exist to proceed further
+ assert os.path.exists(labels_filename)
+ assert os.path.exists(model_filename)
+ assert image_filenames
+ for im in image_filenames:
+ assert (os.path.exists(im))
+
+ # Create a network from a model file
+ net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+ # Load input information from the model and create input tensors
+ input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+ # Load output information from the model and create output tensors
+ output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+ output_tensors = ann.make_output_tensors([output_binding_info])
+
+ # Load labels
+ labels = eu.load_labels(labels_filename)
+
+ # Load images and resize to expected size
+ images = eu.load_images(image_filenames,
+ 224, 224,
+ np.float32,
+ 255.0,
+ [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225],
+ preprocess_onnx)
+
+ eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)