aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorPavel Macenauer <pavel.macenauer@linaro.org>2020-06-02 11:54:59 +0000
committerJim Flynn <jim.flynn@arm.com>2020-08-19 20:44:03 +0000
commit09daef8e9d345cc5e95ee9c9a0ff21b1981da483 (patch)
tree2dfb8ab7ce0bea73d968a39043f55db4a99e93b7 /python
parentc84e45d933a9b45810a3bb88f6873f4eddca0975 (diff)
downloadarmnn-09daef8e9d345cc5e95ee9c9a0ff21b1981da483.tar.gz
Update to provide resources to PyArmNN examples manually
Change-Id: I9ee751512abd5d4ec9faca499b5cea7c19028d22 Signed-off-by: Pavel Macenauer <pavel.macenauer@nxp.com>
Diffstat (limited to 'python')
-rw-r--r--python/pyarmnn/examples/example_utils.py189
-rwxr-xr-xpython/pyarmnn/examples/onnx_mobilenetv2.py88
-rwxr-xr-xpython/pyarmnn/examples/tflite_mobilenetv1_quantized.py85
3 files changed, 243 insertions, 119 deletions
diff --git a/python/pyarmnn/examples/example_utils.py b/python/pyarmnn/examples/example_utils.py
index 5ef30f2331..e5425dde52 100644
--- a/python/pyarmnn/examples/example_utils.py
+++ b/python/pyarmnn/examples/example_utils.py
@@ -2,27 +2,77 @@
# SPDX-License-Identifier: MIT
from urllib.parse import urlparse
-import os
from PIL import Image
+from zipfile import ZipFile
+import os
import pyarmnn as ann
import numpy as np
import requests
import argparse
import warnings
+DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
+
+
+def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
+ """Runs inference on a set of images.
+
+ Args:
+ runtime: Arm NN runtime
+ net_id: Network ID
+ images: Loaded images to run inference on
+ labels: Loaded labels per class
+ input_binding_info: Network input information
+ output_binding_info: Network output information
+
+ Returns:
+ None
+ """
+ output_tensors = ann.make_output_tensors([output_binding_info])
+ for idx, im in enumerate(images):
+ # Create input tensors
+ input_tensors = ann.make_input_tensors([input_binding_info], [im])
+
+ # Run inference
+ print("Running inference({0}) ...".format(idx))
+ runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+
+ # Process output
+ out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+ results = np.argsort(out_tensor)[::-1]
+ print_top_n(5, results, labels, out_tensor)
+
+
+def unzip_file(filename: str):
+ """Unzips a file.
+
+ Args:
+ filename(str): Name of the file
+
+ Returns:
+ None
+ """
+ with ZipFile(filename, 'r') as zip_obj:
+ zip_obj.extractall()
+
def parse_command_line(desc: str = ""):
"""Adds arguments to the script.
Args:
- desc(str): Script description.
+ desc (str): Script description
Returns:
- Namespace: Arguments to the script command.
+ Namespace: Arguments to the script command
"""
parser = argparse.ArgumentParser(description=desc)
parser.add_argument("-v", "--verbose", help="Increase output verbosity",
action="store_true")
+ parser.add_argument("-d", "--data-dir", help="Data directory which contains all the images.",
+ action="store", default="")
+ parser.add_argument("-m", "--model-dir",
+ help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).", action="store",
+ default="")
return parser.parse_args()
@@ -30,15 +80,14 @@ def __create_network(model_file: str, backends: list, parser=None):
"""Creates a network based on a file and parser type.
Args:
- model_file (str): Path of the model file.
+ model_file (str): Path of the model file
backends (list): List of backends to use when running inference.
parser_type: Parser instance. (pyarmnn.ITFliteParser/pyarmnn.IOnnxParser...)
Returns:
- int: Network ID.
- int: Graph ID.
- IParser: TF Lite parser instance.
- IRuntime: Runtime object instance.
+ int: Network ID
+ IParser: TF Lite parser instance
+ IRuntime: Runtime object instance
"""
args = parse_command_line()
options = ann.CreationOptions()
@@ -189,7 +238,7 @@ def print_top_n(N: int, results: list, labels: list, prob: list):
print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))
-def download_file(url: str, force: bool = False, filename: str = None, dest: str = "tmp"):
+def download_file(url: str, force: bool = False, filename: str = None):
"""Downloads a file.
Args:
@@ -197,25 +246,113 @@ def download_file(url: str, force: bool = False, filename: str = None, dest: str
force (bool): Forces to download the file even if it exists.
filename (str): Renames the file when set.
+ Raises:
+ RuntimeError: If for some reason download fails.
+
Returns:
str: Path to the downloaded file.
"""
- if filename is None: # extract filename from url when None
- filename = urlparse(url)
- filename = os.path.basename(filename.path)
-
- if str is not None:
- if not os.path.exists(dest):
- os.makedirs(dest)
- filename = os.path.join(dest, filename)
-
- print("Downloading '{0}' from '{1}' ...".format(filename, url))
- if not os.path.exists(filename) or force is True:
- r = requests.get(url)
- with open(filename, 'wb') as f:
- f.write(r.content)
- print("Finished.")
- else:
- print("File already exists.")
+ try:
+ if filename is None: # extract filename from url when None
+ filename = urlparse(url)
+ filename = os.path.basename(filename.path)
+
+ print("Downloading '{0}' from '{1}' ...".format(filename, url))
+ if not os.path.exists(filename) or force is True:
+ r = requests.get(url)
+ with open(filename, 'wb') as f:
+ f.write(r.content)
+ print("Finished.")
+ else:
+ print("File already exists.")
+ except:
+ raise RuntimeError("Unable to download file.")
return filename
+
+
+def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
+ """Gets model and labels.
+
+ Args:
+ model_dir(str): Folder in which model and label files can be found
+ model (str): Name of the model file
+ labels (str): Name of the labels file
+ archive (str): Name of the archive file (optional - need to provide only labels and model)
+ download_url(str or list): Archive url or urls if multiple files (optional - to to provide only to download it)
+
+ Returns:
+ tuple (str, str): Output label and model filenames
+ """
+ labels = os.path.join(model_dir, labels)
+ model = os.path.join(model_dir, model)
+
+ if os.path.exists(labels) and os.path.exists(model):
+ print("Found model ({0}) and labels ({1}).".format(model, labels))
+ elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
+ print("Found archive ({0}). Unzipping ...".format(archive))
+ unzip_file(archive)
+ elif download_url is not None:
+ print("Model, labels or archive not found. Downloading ...".format(archive))
+ try:
+ if isinstance(download_url, str):
+ download_url = [download_url]
+ for dl in download_url:
+ archive = download_file(dl)
+ if dl.lower().endswith(".zip"):
+ unzip_file(archive)
+ except RuntimeError:
+ print("Unable to download file ({}).".format(archive_url))
+
+ if not os.path.exists(labels) or not os.path.exists(model):
+ raise RuntimeError("Unable to provide model and labels.")
+
+ return model, labels
+
+
+def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+ """Lists files of a certain format in a folder.
+
+ Args:
+ folder (str): Path to the folder to search
+ formats (list): List of supported files
+
+ Returns:
+ list: A list of found files
+ """
+ files = []
+ if folder and not os.path.exists(folder):
+ print("Folder '{}' does not exist.".format(folder))
+ return files
+
+ for file in os.listdir(folder if folder else os.getcwd()):
+ for frmt in formats:
+ if file.lower().endswith(frmt):
+ files.append(os.path.join(folder, file) if folder else file)
+ break # only the format loop
+
+ return files
+
+
+def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
+ """Gets image.
+
+ Args:
+ image (str): Image filename
+ image_url (str): Image url
+
+ Returns:
+ str: Output image filename
+ """
+ images = list_images(image_dir)
+ if not images and image_url is not None:
+ print("No images found. Downloading ...")
+ try:
+ images = [download_file(image_url)]
+ except RuntimeError:
+ print("Unable to download file ({0}).".format(image_url))
+
+ if not images:
+ raise RuntimeError("Unable to provide images.")
+
+ return images
diff --git a/python/pyarmnn/examples/onnx_mobilenetv2.py b/python/pyarmnn/examples/onnx_mobilenetv2.py
index 5ba08499cc..05bfd7b415 100755
--- a/python/pyarmnn/examples/onnx_mobilenetv2.py
+++ b/python/pyarmnn/examples/onnx_mobilenetv2.py
@@ -4,6 +4,7 @@
import pyarmnn as ann
import numpy as np
+import os
from PIL import Image
import example_utils as eu
@@ -43,45 +44,48 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float
return img
-if __name__ == "__main__":
- # Download resources
- kitten_filename = eu.download_file('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
- labels_filename = eu.download_file('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
- model_filename = eu.download_file(
- 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx')
-
- # Create a network from a model file
- net_id, parser, runtime = eu.create_onnx_network(model_filename)
-
- # Load input information from the model and create input tensors
- input_binding_info = parser.GetNetworkInputBindingInfo("data")
-
- # Load output information from the model and create output tensors
- output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
- output_tensors = ann.make_output_tensors([output_binding_info])
-
- # Load labels
- labels = eu.load_labels(labels_filename)
-
- # Load images and resize to expected size
- image_names = [kitten_filename]
- images = eu.load_images(image_names,
- 224, 224,
- np.float32,
- 255.0,
- [0.485, 0.456, 0.406],
- [0.229, 0.224, 0.225],
- preprocess_onnx)
-
- for idx, im in enumerate(images):
- # Create input tensors
- input_tensors = ann.make_input_tensors([input_binding_info], [im])
-
- # Run inference
- print("Running inference on '{0}' ...".format(image_names[idx]))
- runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
-
- # Process output
- out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
- results = np.argsort(out_tensor)[::-1]
- eu.print_top_n(5, results, labels, out_tensor)
+args = eu.parse_command_line()
+
+model_filename = 'mobilenetv2-1.0.onnx'
+labels_filename = 'synset.txt'
+archive_filename = 'mobilenetv2-1.0.zip'
+labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
+model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
+
+# Download resources
+image_filenames = eu.get_images(args.data_dir)
+
+model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename,
+ [model_url, labels_url])
+
+# all 3 resources must exist to proceed further
+assert os.path.exists(labels_filename)
+assert os.path.exists(model_filename)
+assert image_filenames
+for im in image_filenames:
+ assert (os.path.exists(im))
+
+# Create a network from a model file
+net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+# Load input information from the model and create input tensors
+input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+# Load output information from the model and create output tensors
+output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+output_tensors = ann.make_output_tensors([output_binding_info])
+
+# Load labels
+labels = eu.load_labels(labels_filename)
+
+# Load images and resize to expected size
+images = eu.load_images(image_filenames,
+ 224, 224,
+ np.float32,
+ 255.0,
+ [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225],
+ preprocess_onnx)
+
+eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
diff --git a/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py
index aa18a528af..cb2c91cba7 100755
--- a/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py
+++ b/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py
@@ -2,71 +2,54 @@
# Copyright 2020 NXP
# SPDX-License-Identifier: MIT
-from zipfile import ZipFile
import numpy as np
import pyarmnn as ann
import example_utils as eu
import os
+args = eu.parse_command_line()
-def unzip_file(filename):
- """Unzips a file to its current location.
+# names of the files in the archive
+labels_filename = 'labels_mobilenet_quant_v1_224.txt'
+model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
+archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
- Args:
- filename (str): Name of the archive.
+archive_url = \
+ 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
- Returns:
- str: Directory path of the extracted files.
- """
- with ZipFile(filename, 'r') as zip_obj:
- zip_obj.extractall(os.path.dirname(filename))
- return os.path.dirname(filename)
+model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename, archive_url)
+image_filenames = eu.get_images(args.data_dir)
-if __name__ == "__main__":
- # Download resources
- archive_filename = eu.download_file(
- 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip')
- dir_path = unzip_file(archive_filename)
- # names of the files in the archive
- labels_filename = os.path.join(dir_path, 'labels_mobilenet_quant_v1_224.txt')
- model_filename = os.path.join(dir_path, 'mobilenet_v1_1.0_224_quant.tflite')
- kitten_filename = eu.download_file('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
+# all 3 resources must exist to proceed further
+assert os.path.exists(labels_filename)
+assert os.path.exists(model_filename)
+assert image_filenames
+for im in image_filenames:
+ assert(os.path.exists(im))
- # Create a network from the model file
- net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
+# Create a network from the model file
+net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
- # Load input information from the model
- # tflite has all the need information in the model unlike other formats
- input_names = parser.GetSubgraphInputTensorNames(graph_id)
- assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
+# Load input information from the model
+# tflite has all the need information in the model unlike other formats
+input_names = parser.GetSubgraphInputTensorNames(graph_id)
+assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
- input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
- input_width = input_binding_info[1].GetShape()[1]
- input_height = input_binding_info[1].GetShape()[2]
+input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+input_width = input_binding_info[1].GetShape()[1]
+input_height = input_binding_info[1].GetShape()[2]
- # Load output information from the model and create output tensors
- output_names = parser.GetSubgraphOutputTensorNames(graph_id)
- assert len(output_names) == 1 # and only one output tensor
- output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
- output_tensors = ann.make_output_tensors([output_binding_info])
+# Load output information from the model and create output tensors
+output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+assert len(output_names) == 1 # and only one output tensor
+output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
- # Load labels file
- labels = eu.load_labels(labels_filename)
+# Load labels file
+labels = eu.load_labels(labels_filename)
- # Load images and resize to expected size
- image_names = [kitten_filename]
- images = eu.load_images(image_names, input_width, input_height)
+# Load images and resize to expected size
+images = eu.load_images(image_filenames, input_width, input_height)
- for idx, im in enumerate(images):
- # Create input tensors
- input_tensors = ann.make_input_tensors([input_binding_info], [im])
-
- # Run inference
- print("Running inference on '{0}' ...".format(image_names[idx]))
- runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
-
- # Process output
- out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
- results = np.argsort(out_tensor)[::-1]
- eu.print_top_n(5, results, labels, out_tensor)
+eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)