aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py')
-rw-r--r--python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py71
1 files changed, 35 insertions, 36 deletions
diff --git a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
index 229a9b6778..6b35f63a00 100644
--- a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
+++ b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
@@ -2,54 +2,53 @@
# Copyright © 2020 NXP and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
-import numpy as np
-import pyarmnn as ann
import example_utils as eu
import os
-args = eu.parse_command_line()
+if __name__ == "__main__":
+ args = eu.parse_command_line()
-# names of the files in the archive
-labels_filename = 'labels_mobilenet_quant_v1_224.txt'
-model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
-archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
+ # names of the files in the archive
+ labels_filename = 'labels_mobilenet_quant_v1_224.txt'
+ model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
+ archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
-archive_url = \
- 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
+ archive_url = \
+ 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
-model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
- archive_filename, archive_url)
+ model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename, archive_url)
-image_filenames = eu.get_images(args.data_dir)
+ image_filenames = eu.get_images(args.data_dir)
-# all 3 resources must exist to proceed further
-assert os.path.exists(labels_filename)
-assert os.path.exists(model_filename)
-assert image_filenames
-for im in image_filenames:
- assert(os.path.exists(im))
+ # all 3 resources must exist to proceed further
+ assert os.path.exists(labels_filename)
+ assert os.path.exists(model_filename)
+ assert image_filenames
+ for im in image_filenames:
+ assert(os.path.exists(im))
-# Create a network from the model file
-net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
+ # Create a network from the model file
+ net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
-# Load input information from the model
-# tflite has all the need information in the model unlike other formats
-input_names = parser.GetSubgraphInputTensorNames(graph_id)
-assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
+ # Load input information from the model
+ # tflite has all the need information in the model unlike other formats
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
-input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
-input_width = input_binding_info[1].GetShape()[1]
-input_height = input_binding_info[1].GetShape()[2]
+ input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+ input_width = input_binding_info[1].GetShape()[1]
+ input_height = input_binding_info[1].GetShape()[2]
-# Load output information from the model and create output tensors
-output_names = parser.GetSubgraphOutputTensorNames(graph_id)
-assert len(output_names) == 1 # and only one output tensor
-output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
+ # Load output information from the model and create output tensors
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+ assert len(output_names) == 1 # and only one output tensor
+ output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
-# Load labels file
-labels = eu.load_labels(labels_filename)
+ # Load labels file
+ labels = eu.load_labels(labels_filename)
-# Load images and resize to expected size
-images = eu.load_images(image_filenames, input_width, input_height)
+ # Load images and resize to expected size
+ images = eu.load_images(image_filenames, input_width, input_height)
-eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
+ eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)