diff options
author | Éanna Ó Catháin <eanna.ocathain@arm.com> | 2020-11-16 14:12:11 +0000 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-11-17 12:23:56 +0000 |
commit | 145c88f851d12d2cadc2f080d232c1d5963d6e47 (patch) | |
tree | 6ae197d74782cd2c7ef8965f4b36acabc65ce453 /python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py | |
parent | aa41d5d2f43790938f3a32586626be5ef55b6ca9 (diff) | |
download | armnn-145c88f851d12d2cadc2f080d232c1d5963d6e47.tar.gz |
MLECO-1253 Adding ASR sample application using the PyArmNN api
Change-Id: I450b23800ca316a5bfd4608c8559cf4f11271c21
Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Diffstat (limited to 'python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py')
-rw-r--r-- | python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py | 71 |
1 files changed, 35 insertions, 36 deletions
diff --git a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py index 229a9b6778..6b35f63a00 100644 --- a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py +++ b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py @@ -2,54 +2,53 @@ # Copyright © 2020 NXP and Contributors. All rights reserved. # SPDX-License-Identifier: MIT -import numpy as np -import pyarmnn as ann import example_utils as eu import os -args = eu.parse_command_line() +if __name__ == "__main__": + args = eu.parse_command_line() -# names of the files in the archive -labels_filename = 'labels_mobilenet_quant_v1_224.txt' -model_filename = 'mobilenet_v1_1.0_224_quant.tflite' -archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip' + # names of the files in the archive + labels_filename = 'labels_mobilenet_quant_v1_224.txt' + model_filename = 'mobilenet_v1_1.0_224_quant.tflite' + archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip' -archive_url = \ - 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip' + archive_url = \ + 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip' -model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename, - archive_filename, archive_url) + model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename, + archive_filename, archive_url) -image_filenames = eu.get_images(args.data_dir) + image_filenames = eu.get_images(args.data_dir) -# all 3 resources must exist to proceed further -assert os.path.exists(labels_filename) -assert os.path.exists(model_filename) -assert image_filenames -for im in image_filenames: - assert(os.path.exists(im)) + # all 3 resources must exist to proceed further + assert os.path.exists(labels_filename) + assert os.path.exists(model_filename) + assert image_filenames + for im in image_filenames: + assert(os.path.exists(im)) -# Create a network from the model file -net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename) + # Create a network from the model file + net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename) -# Load input information from the model -# tflite has all the need information in the model unlike other formats -input_names = parser.GetSubgraphInputTensorNames(graph_id) -assert len(input_names) == 1 # there should be 1 input tensor in mobilenet + # Load input information from the model + # tflite has all the need information in the model unlike other formats + input_names = parser.GetSubgraphInputTensorNames(graph_id) + assert len(input_names) == 1 # there should be 1 input tensor in mobilenet -input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0]) -input_width = input_binding_info[1].GetShape()[1] -input_height = input_binding_info[1].GetShape()[2] + input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0]) + input_width = input_binding_info[1].GetShape()[1] + input_height = input_binding_info[1].GetShape()[2] -# Load output information from the model and create output tensors -output_names = parser.GetSubgraphOutputTensorNames(graph_id) -assert len(output_names) == 1 # and only one output tensor -output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0]) + # Load output information from the model and create output tensors + output_names = parser.GetSubgraphOutputTensorNames(graph_id) + assert len(output_names) == 1 # and only one output tensor + output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0]) -# Load labels file -labels = eu.load_labels(labels_filename) + # Load labels file + labels = eu.load_labels(labels_filename) -# Load images and resize to expected size -images = eu.load_images(image_filenames, input_width, input_height) + # Load images and resize to expected size + images = eu.load_images(image_filenames, input_width, input_height) -eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info) + eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info) |