diff options
Diffstat (limited to 'python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py')
-rw-r--r-- | python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py | 71 |
1 files changed, 35 insertions, 36 deletions
diff --git a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py index 229a9b6778..6b35f63a00 100644 --- a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py +++ b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py @@ -2,54 +2,53 @@ # Copyright © 2020 NXP and Contributors. All rights reserved. # SPDX-License-Identifier: MIT -import numpy as np -import pyarmnn as ann import example_utils as eu import os -args = eu.parse_command_line() +if __name__ == "__main__": + args = eu.parse_command_line() -# names of the files in the archive -labels_filename = 'labels_mobilenet_quant_v1_224.txt' -model_filename = 'mobilenet_v1_1.0_224_quant.tflite' -archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip' + # names of the files in the archive + labels_filename = 'labels_mobilenet_quant_v1_224.txt' + model_filename = 'mobilenet_v1_1.0_224_quant.tflite' + archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip' -archive_url = \ - 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip' + archive_url = \ + 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip' -model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename, - archive_filename, archive_url) + model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename, + archive_filename, archive_url) -image_filenames = eu.get_images(args.data_dir) + image_filenames = eu.get_images(args.data_dir) -# all 3 resources must exist to proceed further -assert os.path.exists(labels_filename) -assert os.path.exists(model_filename) -assert image_filenames -for im in image_filenames: - assert(os.path.exists(im)) + # all 3 resources must exist to proceed further + assert os.path.exists(labels_filename) + assert os.path.exists(model_filename) + assert image_filenames + for im in image_filenames: + assert(os.path.exists(im)) -# Create a network from the model file -net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename) + # Create a network from the model file + net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename) -# Load input information from the model -# tflite has all the need information in the model unlike other formats -input_names = parser.GetSubgraphInputTensorNames(graph_id) -assert len(input_names) == 1 # there should be 1 input tensor in mobilenet + # Load input information from the model + # tflite has all the need information in the model unlike other formats + input_names = parser.GetSubgraphInputTensorNames(graph_id) + assert len(input_names) == 1 # there should be 1 input tensor in mobilenet -input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0]) -input_width = input_binding_info[1].GetShape()[1] -input_height = input_binding_info[1].GetShape()[2] + input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0]) + input_width = input_binding_info[1].GetShape()[1] + input_height = input_binding_info[1].GetShape()[2] -# Load output information from the model and create output tensors -output_names = parser.GetSubgraphOutputTensorNames(graph_id) -assert len(output_names) == 1 # and only one output tensor -output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0]) + # Load output information from the model and create output tensors + output_names = parser.GetSubgraphOutputTensorNames(graph_id) + assert len(output_names) == 1 # and only one output tensor + output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0]) -# Load labels file -labels = eu.load_labels(labels_filename) + # Load labels file + labels = eu.load_labels(labels_filename) -# Load images and resize to expected size -images = eu.load_images(image_filenames, input_width, input_height) + # Load images and resize to expected size + images = eu.load_images(image_filenames, input_width, input_height) -eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info) + eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info) |