aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/image_classification/example_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/image_classification/example_utils.py')
-rw-r--r--python/pyarmnn/examples/image_classification/example_utils.py358
1 files changed, 358 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/image_classification/example_utils.py b/python/pyarmnn/examples/image_classification/example_utils.py
new file mode 100644
index 0000000000..090ce2f5b3
--- /dev/null
+++ b/python/pyarmnn/examples/image_classification/example_utils.py
@@ -0,0 +1,358 @@
+# Copyright © 2020 NXP and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from urllib.parse import urlparse
+from PIL import Image
+from zipfile import ZipFile
+import os
+import pyarmnn as ann
+import numpy as np
+import requests
+import argparse
+import warnings
+
+DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
+
+
+def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
+ """Runs inference on a set of images.
+
+ Args:
+ runtime: Arm NN runtime
+ net_id: Network ID
+ images: Loaded images to run inference on
+ labels: Loaded labels per class
+ input_binding_info: Network input information
+ output_binding_info: Network output information
+
+ Returns:
+ None
+ """
+ output_tensors = ann.make_output_tensors([output_binding_info])
+ for idx, im in enumerate(images):
+ # Create input tensors
+ input_tensors = ann.make_input_tensors([input_binding_info], [im])
+
+ # Run inference
+ print("Running inference({0}) ...".format(idx))
+ runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+
+ # Process output
+ out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+ results = np.argsort(out_tensor)[::-1]
+ print_top_n(5, results, labels, out_tensor)
+
+
+def unzip_file(filename: str):
+ """Unzips a file.
+
+ Args:
+ filename(str): Name of the file
+
+ Returns:
+ None
+ """
+ with ZipFile(filename, 'r') as zip_obj:
+ zip_obj.extractall()
+
+
+def parse_command_line(desc: str = ""):
+ """Adds arguments to the script.
+
+ Args:
+ desc (str): Script description
+
+ Returns:
+ Namespace: Arguments to the script command
+ """
+ parser = argparse.ArgumentParser(description=desc)
+ parser.add_argument("-v", "--verbose", help="Increase output verbosity",
+ action="store_true")
+ parser.add_argument("-d", "--data-dir", help="Data directory which contains all the images.",
+ action="store", default="")
+ parser.add_argument("-m", "--model-dir",
+ help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).", action="store",
+ default="")
+ return parser.parse_args()
+
+
+def __create_network(model_file: str, backends: list, parser=None):
+ """Creates a network based on a file and parser type.
+
+ Args:
+ model_file (str): Path of the model file
+ backends (list): List of backends to use when running inference.
+ parser_type: Parser instance. (pyarmnn.ITFliteParser/pyarmnn.IOnnxParser...)
+
+ Returns:
+ int: Network ID
+ IParser: TF Lite parser instance
+ IRuntime: Runtime object instance
+ """
+ args = parse_command_line()
+ options = ann.CreationOptions()
+ runtime = ann.IRuntime(options)
+
+ if parser is None:
+ # try to determine what parser to create based on model extension
+ _, ext = os.path.splitext(model_file)
+ if ext == ".onnx":
+ parser = ann.IOnnxParser()
+ elif ext == ".tflite":
+ parser = ann.ITfLiteParser()
+ assert (parser is not None)
+
+ network = parser.CreateNetworkFromBinaryFile(model_file)
+
+ preferred_backends = []
+ for b in backends:
+ preferred_backends.append(ann.BackendId(b))
+
+ opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
+ ann.OptimizerOptions())
+ if args.verbose:
+ for m in messages:
+ warnings.warn(m)
+
+ net_id, w = runtime.LoadNetwork(opt_network)
+ if args.verbose and w:
+ warnings.warn(w)
+
+ return net_id, parser, runtime
+
+
+def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+ """Creates a network from a tflite model file.
+
+ Args:
+ model_file (str): Path of the model file.
+ backends (list): List of backends to use when running inference.
+
+ Returns:
+ int: Network ID.
+ int: Graph ID.
+ ITFliteParser: TF Lite parser instance.
+ IRuntime: Runtime object instance.
+ """
+ net_id, parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
+ graph_id = parser.GetSubgraphCount() - 1
+
+ return net_id, graph_id, parser, runtime
+
+
+def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+ """Creates a network from an onnx model file.
+
+ Args:
+ model_file (str): Path of the model file.
+ backends (list): List of backends to use when running inference.
+
+ Returns:
+ int: Network ID.
+ IOnnxParser: ONNX parser instance.
+ IRuntime: Runtime object instance.
+ """
+ return __create_network(model_file, backends, ann.IOnnxParser())
+
+
+def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
+ stddev: list):
+ """Default preprocessing image function.
+
+ Args:
+ img (PIL.Image): PIL.Image object instance.
+ width (int): Width to resize to.
+ height (int): Height to resize to.
+ data_type: Data Type to cast the image to.
+ scale (float): Scaling value.
+ mean (list): RGB mean offset.
+ stddev (list): RGB standard deviation.
+
+ Returns:
+ np.array: Resized and preprocessed image.
+ """
+ img = img.resize((width, height), Image.BILINEAR)
+ img = img.convert('RGB')
+ img = np.array(img)
+ img = np.reshape(img, (-1, 3)) # reshape to [RGB][RGB]...
+ img = ((img / scale) - mean) / stddev
+ img = img.flatten().astype(data_type)
+ return img
+
+
+def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
+ scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
+ preprocess_fn=preprocess_default):
+ """Loads images, resizes and performs any additional preprocessing to run inference.
+
+ Args:
+ img (list): List of PIL.Image object instances.
+ input_width (int): Width to resize to.
+ input_height (int): Height to resize to.
+ data_type: Data Type to cast the image to.
+ scale (float): Scaling value.
+ mean (list): RGB mean offset.
+ stddev (list): RGB standard deviation.
+ preprocess_fn: Preprocessing function.
+
+ Returns:
+ np.array: Resized and preprocessed images.
+ """
+ images = []
+ for i in image_files:
+ img = Image.open(i)
+ img = preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev)
+ images.append(img)
+ return images
+
+
+def load_labels(label_file: str):
+ """Loads a labels file containing a label per line.
+
+ Args:
+ label_file (str): Labels file path.
+
+ Returns:
+ list: List of labels read from a file.
+ """
+ with open(label_file, 'r') as f:
+ labels = [l.rstrip() for l in f]
+ return labels
+ return None
+
+
+def print_top_n(N: int, results: list, labels: list, prob: list):
+ """Prints TOP-N results
+
+ Args:
+ N (int): Result count to print.
+ results (list): Top prediction indices.
+ labels (list): A list of labels for every class.
+ prob (list): A list of probabilities for every class.
+
+ Returns:
+ None
+ """
+ assert (len(results) >= 1 and len(results) == len(labels) == len(prob))
+ for i in range(min(len(results), N)):
+ print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))
+
+
+def download_file(url: str, force: bool = False, filename: str = None):
+ """Downloads a file.
+
+ Args:
+ url (str): File url.
+ force (bool): Forces to download the file even if it exists.
+ filename (str): Renames the file when set.
+
+ Raises:
+ RuntimeError: If for some reason download fails.
+
+ Returns:
+ str: Path to the downloaded file.
+ """
+ try:
+ if filename is None: # extract filename from url when None
+ filename = urlparse(url)
+ filename = os.path.basename(filename.path)
+
+ print("Downloading '{0}' from '{1}' ...".format(filename, url))
+ if not os.path.exists(filename) or force is True:
+ r = requests.get(url)
+ with open(filename, 'wb') as f:
+ f.write(r.content)
+ print("Finished.")
+ else:
+ print("File already exists.")
+ except:
+ raise RuntimeError("Unable to download file.")
+
+ return filename
+
+
+def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
+ """Gets model and labels.
+
+ Args:
+ model_dir(str): Folder in which model and label files can be found
+ model (str): Name of the model file
+ labels (str): Name of the labels file
+ archive (str): Name of the archive file (optional - need to provide only labels and model)
+ download_url(str or list): Archive url or urls if multiple files (optional - to to provide only to download it)
+
+ Returns:
+ tuple (str, str): Output label and model filenames
+ """
+ labels = os.path.join(model_dir, labels)
+ model = os.path.join(model_dir, model)
+
+ if os.path.exists(labels) and os.path.exists(model):
+ print("Found model ({0}) and labels ({1}).".format(model, labels))
+ elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
+ print("Found archive ({0}). Unzipping ...".format(archive))
+ unzip_file(archive)
+ elif download_url is not None:
+ print("Model, labels or archive not found. Downloading ...".format(archive))
+ try:
+ if isinstance(download_url, str):
+ download_url = [download_url]
+ for dl in download_url:
+ archive = download_file(dl)
+ if dl.lower().endswith(".zip"):
+ unzip_file(archive)
+ except RuntimeError:
+ print("Unable to download file ({}).".format(archive_url))
+
+ if not os.path.exists(labels) or not os.path.exists(model):
+ raise RuntimeError("Unable to provide model and labels.")
+
+ return model, labels
+
+
+def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+ """Lists files of a certain format in a folder.
+
+ Args:
+ folder (str): Path to the folder to search
+ formats (list): List of supported files
+
+ Returns:
+ list: A list of found files
+ """
+ files = []
+ if folder and not os.path.exists(folder):
+ print("Folder '{}' does not exist.".format(folder))
+ return files
+
+ for file in os.listdir(folder if folder else os.getcwd()):
+ for frmt in formats:
+ if file.lower().endswith(frmt):
+ files.append(os.path.join(folder, file) if folder else file)
+ break # only the format loop
+
+ return files
+
+
+def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
+ """Gets image.
+
+ Args:
+ image (str): Image filename
+ image_url (str): Image url
+
+ Returns:
+ str: Output image filename
+ """
+ images = list_images(image_dir)
+ if not images and image_url is not None:
+ print("No images found. Downloading ...")
+ try:
+ images = [download_file(image_url)]
+ except RuntimeError:
+ print("Unable to download file ({0}).".format(image_url))
+
+ if not images:
+ raise RuntimeError("Unable to provide images.")
+
+ return images