diff options
Diffstat (limited to 'python/pyarmnn/examples/onnx_mobilenetv2.py')
-rwxr-xr-x | python/pyarmnn/examples/onnx_mobilenetv2.py | 88 |
1 files changed, 46 insertions, 42 deletions
diff --git a/python/pyarmnn/examples/onnx_mobilenetv2.py b/python/pyarmnn/examples/onnx_mobilenetv2.py index 5ba08499cc..05bfd7b415 100755 --- a/python/pyarmnn/examples/onnx_mobilenetv2.py +++ b/python/pyarmnn/examples/onnx_mobilenetv2.py @@ -4,6 +4,7 @@ import pyarmnn as ann import numpy as np +import os from PIL import Image import example_utils as eu @@ -43,45 +44,48 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float return img -if __name__ == "__main__": - # Download resources - kitten_filename = eu.download_file('https://s3.amazonaws.com/model-server/inputs/kitten.jpg') - labels_filename = eu.download_file('https://s3.amazonaws.com/onnx-model-zoo/synset.txt') - model_filename = eu.download_file( - 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx') - - # Create a network from a model file - net_id, parser, runtime = eu.create_onnx_network(model_filename) - - # Load input information from the model and create input tensors - input_binding_info = parser.GetNetworkInputBindingInfo("data") - - # Load output information from the model and create output tensors - output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0") - output_tensors = ann.make_output_tensors([output_binding_info]) - - # Load labels - labels = eu.load_labels(labels_filename) - - # Load images and resize to expected size - image_names = [kitten_filename] - images = eu.load_images(image_names, - 224, 224, - np.float32, - 255.0, - [0.485, 0.456, 0.406], - [0.229, 0.224, 0.225], - preprocess_onnx) - - for idx, im in enumerate(images): - # Create input tensors - input_tensors = ann.make_input_tensors([input_binding_info], [im]) - - # Run inference - print("Running inference on '{0}' ...".format(image_names[idx])) - runtime.EnqueueWorkload(net_id, input_tensors, output_tensors) - - # Process output - out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0] - results = np.argsort(out_tensor)[::-1] - eu.print_top_n(5, results, labels, out_tensor) +args = eu.parse_command_line() + +model_filename = 'mobilenetv2-1.0.onnx' +labels_filename = 'synset.txt' +archive_filename = 'mobilenetv2-1.0.zip' +labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename +model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename + +# Download resources +image_filenames = eu.get_images(args.data_dir) + +model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename, + archive_filename, + [model_url, labels_url]) + +# all 3 resources must exist to proceed further +assert os.path.exists(labels_filename) +assert os.path.exists(model_filename) +assert image_filenames +for im in image_filenames: + assert (os.path.exists(im)) + +# Create a network from a model file +net_id, parser, runtime = eu.create_onnx_network(model_filename) + +# Load input information from the model and create input tensors +input_binding_info = parser.GetNetworkInputBindingInfo("data") + +# Load output information from the model and create output tensors +output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0") +output_tensors = ann.make_output_tensors([output_binding_info]) + +# Load labels +labels = eu.load_labels(labels_filename) + +# Load images and resize to expected size +images = eu.load_images(image_filenames, + 224, 224, + np.float32, + 255.0, + [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225], + preprocess_onnx) + +eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info) |