Diffstat (limited to 'python/pyarmnn/examples/image_classification')
-rw-r--r-- python/pyarmnn/examples/image_classification/README.md                       |  46
-rw-r--r-- python/pyarmnn/examples/image_classification/example_utils.py                | 358
-rw-r--r-- python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py             |  91
-rw-r--r-- python/pyarmnn/examples/image_classification/requirements.txt                |   4
-rw-r--r-- python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py |  55
5 files changed, 554 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/image_classification/README.md b/python/pyarmnn/examples/image_classification/README.md
new file mode 100644
index 0000000000..61efbc421f
--- /dev/null
+++ b/python/pyarmnn/examples/image_classification/README.md
@@ -0,0 +1,46 @@
+# PyArmNN Image Classification Sample Application
+
+## Overview
+
+To further explore the PyArmNN API, we provide two examples of running image classification on an image: one using a quantized TFLite MobileNet V1 model and one using an ONNX MobileNet V2 model.
+
+All resources are downloaded during execution, so if you do not have internet access, you may need to download them manually beforehand. The file `example_utils.py` contains code shared between the examples.
+
+## Prerequisites
+
+##### PyArmNN
+
+Before proceeding to the next steps, make sure that you have successfully installed the newest version of PyArmNN on your system by following the instructions in the README of the PyArmNN root directory.
+
+You can verify that the PyArmNN library is installed and check its version using:
+```bash
+$ pip show pyarmnn
+```
+
+You can also verify the installation by running the following command and checking for output similar to this:
+```bash
+$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
+'22.0.0'
+```
+
+##### Dependencies
+
+Install the dependencies:
+
+```bash
+$ pip install -r requirements.txt
+```
+
+## Perform Image Classification
+
+Perform inference with the TFLite model by running the sample script:
+```bash
+$ python tflite_mobilenetv1_quantized.py
+```
+
+Perform inference with the ONNX model by running the sample script:
+```bash
+$ python onnx_mobilenetv2.py
+```
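+
+Both scripts accept the optional command-line arguments defined in `example_utils.py`: `-d`/`--data-dir` to point at a directory of images, `-m`/`--model-dir` to point at a directory containing the model and labels files, and `-v`/`--verbose` to print optimizer and loader warnings. For example:
+```bash
+$ python tflite_mobilenetv1_quantized.py --data-dir ./images --model-dir ./models
+```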
+
+The output from inference will be printed as *Top N* results, listing the classes and probabilities associated with the classified image.
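+
+A run over the default kitten image might print something like the following (class names and values are illustrative and depend on the model, the labels file and the input):
+```
+class=tabby ; value=0.42
+class=tiger cat ; value=0.31
+class=Egyptian cat ; value=0.12
+class=lynx ; value=0.02
+class=plastic bag ; value=0.01
+```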
diff --git a/python/pyarmnn/examples/image_classification/example_utils.py b/python/pyarmnn/examples/image_classification/example_utils.py
new file mode 100644
index 0000000000..090ce2f5b3
--- /dev/null
+++ b/python/pyarmnn/examples/image_classification/example_utils.py
@@ -0,0 +1,358 @@
+# Copyright © 2020 NXP and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from urllib.parse import urlparse
+from PIL import Image
+from zipfile import ZipFile
+import os
+import pyarmnn as ann
+import numpy as np
+import requests
+import argparse
+import warnings
+
+DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
+
+
+def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
+ """Runs inference on a set of images.
+
+ Args:
+ runtime: Arm NN runtime
+ net_id: Network ID
+ images: Loaded images to run inference on
+ labels: Loaded labels per class
+ input_binding_info: Network input information
+ output_binding_info: Network output information
+
+ Returns:
+ None
+ """
+ output_tensors = ann.make_output_tensors([output_binding_info])
+ for idx, im in enumerate(images):
+ # Create input tensors
+ input_tensors = ann.make_input_tensors([input_binding_info], [im])
+
+ # Run inference
+ print("Running inference({0}) ...".format(idx))
+ runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+
+ # Process output
+ out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+ results = np.argsort(out_tensor)[::-1]
+ print_top_n(5, results, labels, out_tensor)
+
+
+def unzip_file(filename: str):
+ """Unzips a file.
+
+ Args:
+ filename(str): Name of the file
+
+ Returns:
+ None
+ """
+ with ZipFile(filename, 'r') as zip_obj:
+ zip_obj.extractall()
+
+
+def parse_command_line(desc: str = ""):
+ """Adds arguments to the script.
+
+ Args:
+ desc (str): Script description
+
+ Returns:
+ Namespace: Arguments to the script command
+ """
+ parser = argparse.ArgumentParser(description=desc)
+ parser.add_argument("-v", "--verbose", help="Increase output verbosity",
+ action="store_true")
+ parser.add_argument("-d", "--data-dir", help="Data directory which contains all the images.",
+ action="store", default="")
+ parser.add_argument("-m", "--model-dir",
+ help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).", action="store",
+ default="")
+ return parser.parse_args()
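+
+# Example invocation of a sample script that uses parse_command_line()
+# (illustrative):
+#   $ python tflite_mobilenetv1_quantized.py --data-dir ./images --model-dir ./models --verbose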
+
+
+def __create_network(model_file: str, backends: list, parser=None):
+ """Creates a network based on a file and parser type.
+
+ Args:
+ model_file (str): Path of the model file
+ backends (list): List of backends to use when running inference.
+        parser: Parser instance (pyarmnn.ITfLiteParser/pyarmnn.IOnnxParser...).
+
+ Returns:
+ int: Network ID
+        IParser: Parser instance.
+ IRuntime: Runtime object instance
+ """
+ args = parse_command_line()
+ options = ann.CreationOptions()
+ runtime = ann.IRuntime(options)
+
+ if parser is None:
+ # try to determine what parser to create based on model extension
+ _, ext = os.path.splitext(model_file)
+ if ext == ".onnx":
+ parser = ann.IOnnxParser()
+ elif ext == ".tflite":
+ parser = ann.ITfLiteParser()
+    assert parser is not None, "Unsupported model file extension: '{}'".format(ext)
+
+ network = parser.CreateNetworkFromBinaryFile(model_file)
+
+ preferred_backends = []
+ for b in backends:
+ preferred_backends.append(ann.BackendId(b))
+
+ opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
+ ann.OptimizerOptions())
+ if args.verbose:
+ for m in messages:
+ warnings.warn(m)
+
+ net_id, w = runtime.LoadNetwork(opt_network)
+ if args.verbose and w:
+ warnings.warn(w)
+
+ return net_id, parser, runtime
+
+
+def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+ """Creates a network from a tflite model file.
+
+ Args:
+ model_file (str): Path of the model file.
+ backends (list): List of backends to use when running inference.
+
+ Returns:
+ int: Network ID.
+ int: Graph ID.
+ ITFliteParser: TF Lite parser instance.
+ IRuntime: Runtime object instance.
+ """
+ net_id, parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
+ graph_id = parser.GetSubgraphCount() - 1
+
+ return net_id, graph_id, parser, runtime
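+
+# Minimal usage sketch (illustrative; assumes 'model.tflite' exists in the
+# working directory):
+#   net_id, graph_id, parser, runtime = create_tflite_network('model.tflite')
+#   input_names = parser.GetSubgraphInputTensorNames(graph_id)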
+
+
+def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+ """Creates a network from an onnx model file.
+
+ Args:
+ model_file (str): Path of the model file.
+ backends (list): List of backends to use when running inference.
+
+ Returns:
+ int: Network ID.
+ IOnnxParser: ONNX parser instance.
+ IRuntime: Runtime object instance.
+ """
+ return __create_network(model_file, backends, ann.IOnnxParser())
+
+
+def preprocess_default(img: Image.Image, width: int, height: int, data_type, scale: float, mean: list,
+                       stddev: list):
+ """Default preprocessing image function.
+
+ Args:
+ img (PIL.Image): PIL.Image object instance.
+ width (int): Width to resize to.
+ height (int): Height to resize to.
+ data_type: Data Type to cast the image to.
+ scale (float): Scaling value.
+ mean (list): RGB mean offset.
+ stddev (list): RGB standard deviation.
+
+ Returns:
+ np.array: Resized and preprocessed image.
+ """
+ img = img.resize((width, height), Image.BILINEAR)
+ img = img.convert('RGB')
+ img = np.array(img)
+ img = np.reshape(img, (-1, 3)) # reshape to [RGB][RGB]...
+ img = ((img / scale) - mean) / stddev
+ img = img.flatten().astype(data_type)
+ return img
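+
+# Worked example for the normalization above (illustrative): with scale=255.0,
+# mean=[0.485, 0.456, 0.406] and stddev=[0.229, 0.224, 0.225] (the standard
+# ImageNet statistics), a pure white pixel [255, 255, 255] maps to roughly
+# [2.25, 2.43, 2.64] after ((img / scale) - mean) / stddev.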
+
+
+def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
+ scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
+ preprocess_fn=preprocess_default):
+ """Loads images, resizes and performs any additional preprocessing to run inference.
+
+ Args:
+        image_files (list): List of image file paths.
+ input_width (int): Width to resize to.
+ input_height (int): Height to resize to.
+ data_type: Data Type to cast the image to.
+ scale (float): Scaling value.
+ mean (list): RGB mean offset.
+ stddev (list): RGB standard deviation.
+ preprocess_fn: Preprocessing function.
+
+ Returns:
+ np.array: Resized and preprocessed images.
+ """
+ images = []
+ for i in image_files:
+ img = Image.open(i)
+ img = preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev)
+ images.append(img)
+ return images
+
+
+def load_labels(label_file: str):
+ """Loads a labels file containing a label per line.
+
+ Args:
+ label_file (str): Labels file path.
+
+ Returns:
+ list: List of labels read from a file.
+ """
+ with open(label_file, 'r') as f:
+ labels = [l.rstrip() for l in f]
+ return labels
+
+
+def print_top_n(N: int, results: list, labels: list, prob: list):
+ """Prints TOP-N results
+
+ Args:
+ N (int): Result count to print.
+ results (list): Top prediction indices.
+ labels (list): A list of labels for every class.
+ prob (list): A list of probabilities for every class.
+
+ Returns:
+ None
+ """
+ assert (len(results) >= 1 and len(results) == len(labels) == len(prob))
+ for i in range(min(len(results), N)):
+ print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))
+
+
+def download_file(url: str, force: bool = False, filename: str = None):
+ """Downloads a file.
+
+ Args:
+ url (str): File url.
+        force (bool): Forces the download even if the file already exists.
+        filename (str): Overrides the local filename when set.
+
+ Raises:
+ RuntimeError: If for some reason download fails.
+
+ Returns:
+ str: Path to the downloaded file.
+ """
+ try:
+ if filename is None: # extract filename from url when None
+ filename = urlparse(url)
+ filename = os.path.basename(filename.path)
+
+ print("Downloading '{0}' from '{1}' ...".format(filename, url))
+ if not os.path.exists(filename) or force is True:
+ r = requests.get(url)
+ with open(filename, 'wb') as f:
+ f.write(r.content)
+ print("Finished.")
+ else:
+ print("File already exists.")
+ except:
+ raise RuntimeError("Unable to download file.")
+
+ return filename
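+
+# Example (illustrative): download_file(DEFAULT_IMAGE_URL) saves 'kitten.jpg'
+# into the working directory and returns that filename.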
+
+
+def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
+ """Gets model and labels.
+
+ Args:
+        model_dir (str): Folder in which the model and labels files can be found
+        model (str): Name of the model file
+        labels (str): Name of the labels file
+        archive (str): Name of the archive file (optional - provide when the model and labels are packed in an archive)
+        download_url (str or list): Archive URL, or a list of URLs for multiple files (optional - provide only if the files should be downloaded)
+
+ Returns:
+        tuple (str, str): Output model and labels filenames
+ """
+ labels = os.path.join(model_dir, labels)
+ model = os.path.join(model_dir, model)
+
+ if os.path.exists(labels) and os.path.exists(model):
+ print("Found model ({0}) and labels ({1}).".format(model, labels))
+    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
+        print("Found archive ({0}). Unzipping ...".format(archive))
+        unzip_file(os.path.join(model_dir, archive))
+ elif download_url is not None:
+ print("Model, labels or archive not found. Downloading ...".format(archive))
+ try:
+ if isinstance(download_url, str):
+ download_url = [download_url]
+ for dl in download_url:
+ archive = download_file(dl)
+ if dl.lower().endswith(".zip"):
+ unzip_file(archive)
+ except RuntimeError:
+ print("Unable to download file ({}).".format(archive_url))
+
+ if not os.path.exists(labels) or not os.path.exists(model):
+ raise RuntimeError("Unable to provide model and labels.")
+
+ return model, labels
+
+
+def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+ """Lists files of a certain format in a folder.
+
+ Args:
+ folder (str): Path to the folder to search
+        formats (list): List of supported file extensions
+
+ Returns:
+ list: A list of found files
+ """
+ files = []
+ if folder and not os.path.exists(folder):
+ print("Folder '{}' does not exist.".format(folder))
+ return files
+
+ for file in os.listdir(folder if folder else os.getcwd()):
+ for frmt in formats:
+ if file.lower().endswith(frmt):
+ files.append(os.path.join(folder, file) if folder else file)
+                break  # only break out of the inner format loop
+
+ return files
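+
+# Example (illustrative): list_images('./images') could return
+# ['./images/cat.jpg', './images/dog.jpeg'] if those files exist there.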
+
+
+def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
+ """Gets image.
+
+ Args:
+ image (str): Image filename
+ image_url (str): Image url
+
+ Returns:
+ str: Output image filename
+ """
+ images = list_images(image_dir)
+ if not images and image_url is not None:
+ print("No images found. Downloading ...")
+ try:
+ images = [download_file(image_url)]
+ except RuntimeError:
+ print("Unable to download file ({0}).".format(image_url))
+
+ if not images:
+ raise RuntimeError("Unable to provide images.")
+
+ return images
diff --git a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
new file mode 100644
index 0000000000..9e95f76dcc
--- /dev/null
+++ b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+# Copyright © 2020 NXP and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import pyarmnn as ann
+import numpy as np
+import os
+from PIL import Image
+import example_utils as eu
+
+
+def preprocess_onnx(img: Image.Image, width: int, height: int, data_type, scale: float, mean: list,
+                    stddev: list):
+ """Preprocessing function for ONNX imagenet models based on:
+ https://github.com/onnx/models/blob/master/vision/classification/imagenet_inference.ipynb
+
+ Args:
+ img (PIL.Image): Loaded PIL.Image
+ width (int): Target image width
+ height (int): Target image height
+ data_type: Image datatype (np.uint8 or np.float32)
+ scale (float): Scaling factor
+ mean: RGB mean values
+ stddev: RGB standard deviation
+
+ Returns:
+        np.array: Preprocessed image as a NumPy array
+ """
+    # first rescale to 256x256, then center crop to width x height
+    img = img.resize((256, 256), Image.BILINEAR)
+ left = (256 - width) / 2
+ top = (256 - height) / 2
+ right = (256 + width) / 2
+ bottom = (256 + height) / 2
+ img = img.crop((left, top, right, bottom))
+ img = img.convert('RGB')
+ img = np.array(img)
+ img = np.reshape(img, (-1, 3)) # reshape to [RGB][RGB]...
+ img = ((img / scale) - mean) / stddev
+ # NHWC to NCHW conversion, by default NHWC is expected
+ # image is loaded as [RGB][RGB][RGB]... transposing it makes it [RRR...][GGG...][BBB...]
+ img = np.transpose(img)
+    img = img.flatten().astype(data_type)  # flatten into a 1D tensor and cast to the requested data type
+ return img
+
+
+args = eu.parse_command_line()
+
+model_filename = 'mobilenetv2-1.0.onnx'
+labels_filename = 'synset.txt'
+archive_filename = 'mobilenetv2-1.0.zip'
+labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
+model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
+
+# Download resources
+image_filenames = eu.get_images(args.data_dir)
+
+model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename,
+ [model_url, labels_url])
+
+# all 3 resources must exist to proceed further
+assert os.path.exists(labels_filename)
+assert os.path.exists(model_filename)
+assert image_filenames
+for im in image_filenames:
+ assert (os.path.exists(im))
+
+# Create a network from a model file
+net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+# Load input information from the model and create input tensors
+input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+# Load output information from the model and create output tensors
+output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+output_tensors = ann.make_output_tensors([output_binding_info])
+
+# Load labels
+labels = eu.load_labels(labels_filename)
+
+# Load images and resize to the expected size; the scale, mean and stddev
+# values below are the standard ImageNet normalization statistics
+images = eu.load_images(image_filenames,
+ 224, 224,
+ np.float32,
+ 255.0,
+ [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225],
+ preprocess_onnx)
+
+eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
diff --git a/python/pyarmnn/examples/image_classification/requirements.txt b/python/pyarmnn/examples/image_classification/requirements.txt
new file mode 100644
index 0000000000..f97e85636e
--- /dev/null
+++ b/python/pyarmnn/examples/image_classification/requirements.txt
@@ -0,0 +1,4 @@
+requests>=2.23.0
+urllib3>=1.25.8
+Pillow>=6.1.0
+numpy>=1.18.1
diff --git a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
new file mode 100644
index 0000000000..229a9b6778
--- /dev/null
+++ b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# Copyright © 2020 NXP and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import numpy as np
+import pyarmnn as ann
+import example_utils as eu
+import os
+
+args = eu.parse_command_line()
+
+# names of the files in the archive
+labels_filename = 'labels_mobilenet_quant_v1_224.txt'
+model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
+archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
+
+archive_url = \
+ 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
+
+model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename, archive_url)
+
+image_filenames = eu.get_images(args.data_dir)
+
+# all 3 resources must exist to proceed further
+assert os.path.exists(labels_filename)
+assert os.path.exists(model_filename)
+assert image_filenames
+for im in image_filenames:
+ assert(os.path.exists(im))
+
+# Create a network from the model file
+net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
+
+# Load input information from the model
+# TFLite models carry all the needed information in the model file, unlike some other formats
+input_names = parser.GetSubgraphInputTensorNames(graph_id)
+assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
+
+input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+input_height = input_binding_info[1].GetShape()[1]  # NHWC layout: [batch, height, width, channels]
+input_width = input_binding_info[1].GetShape()[2]
+
+# Load output information from the model and create output tensors
+output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+assert len(output_names) == 1 # and only one output tensor
+output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
+
+# Load labels file
+labels = eu.load_labels(labels_filename)
+
+# Load images and resize to expected size
+images = eu.load_images(image_filenames, input_width, input_height)
+
+eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)