diff options
Diffstat (limited to 'python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py')
-rw-r--r-- | python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py | 91 |
1 files changed, 46 insertions, 45 deletions
diff --git a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py index 9e95f76dcc..be28b585ba 100644 --- a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py +++ b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py @@ -44,48 +44,49 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float return img -args = eu.parse_command_line() - -model_filename = 'mobilenetv2-1.0.onnx' -labels_filename = 'synset.txt' -archive_filename = 'mobilenetv2-1.0.zip' -labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename -model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename - -# Download resources -image_filenames = eu.get_images(args.data_dir) - -model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename, - archive_filename, - [model_url, labels_url]) - -# all 3 resources must exist to proceed further -assert os.path.exists(labels_filename) -assert os.path.exists(model_filename) -assert image_filenames -for im in image_filenames: - assert (os.path.exists(im)) - -# Create a network from a model file -net_id, parser, runtime = eu.create_onnx_network(model_filename) - -# Load input information from the model and create input tensors -input_binding_info = parser.GetNetworkInputBindingInfo("data") - -# Load output information from the model and create output tensors -output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0") -output_tensors = ann.make_output_tensors([output_binding_info]) - -# Load labels -labels = eu.load_labels(labels_filename) - -# Load images and resize to expected size -images = eu.load_images(image_filenames, - 224, 224, - np.float32, - 255.0, - [0.485, 0.456, 0.406], - [0.229, 0.224, 0.225], - preprocess_onnx) - -eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info) +if __name__ == "__main__": + args = eu.parse_command_line() + + model_filename = 'mobilenetv2-1.0.onnx' + labels_filename = 'synset.txt' + archive_filename = 'mobilenetv2-1.0.zip' + labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename + model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename + + # Download resources + image_filenames = eu.get_images(args.data_dir) + + model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename, + archive_filename, + [model_url, labels_url]) + + # all 3 resources must exist to proceed further + assert os.path.exists(labels_filename) + assert os.path.exists(model_filename) + assert image_filenames + for im in image_filenames: + assert (os.path.exists(im)) + + # Create a network from a model file + net_id, parser, runtime = eu.create_onnx_network(model_filename) + + # Load input information from the model and create input tensors + input_binding_info = parser.GetNetworkInputBindingInfo("data") + + # Load output information from the model and create output tensors + output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0") + output_tensors = ann.make_output_tensors([output_binding_info]) + + # Load labels + labels = eu.load_labels(labels_filename) + + # Load images and resize to expected size + images = eu.load_images(image_filenames, + 224, 224, + np.float32, + 255.0, + [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225], + preprocess_onnx) + + eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info) |