aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/example_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/example_utils.py')
-rw-r--r--python/pyarmnn/examples/example_utils.py189
1 files changed, 163 insertions, 26 deletions
diff --git a/python/pyarmnn/examples/example_utils.py b/python/pyarmnn/examples/example_utils.py
index 5ef30f2331..e5425dde52 100644
--- a/python/pyarmnn/examples/example_utils.py
+++ b/python/pyarmnn/examples/example_utils.py
@@ -2,27 +2,77 @@
# SPDX-License-Identifier: MIT
from urllib.parse import urlparse
-import os
from PIL import Image
+from zipfile import ZipFile
+import os
import pyarmnn as ann
import numpy as np
import requests
import argparse
import warnings
+DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
+
+
+def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
+ """Runs inference on a set of images.
+
+ Args:
+ runtime: Arm NN runtime
+ net_id: Network ID
+ images: Loaded images to run inference on
+ labels: Loaded labels per class
+ input_binding_info: Network input information
+ output_binding_info: Network output information
+
+ Returns:
+ None
+ """
+ output_tensors = ann.make_output_tensors([output_binding_info])
+ for idx, im in enumerate(images):
+ # Create input tensors
+ input_tensors = ann.make_input_tensors([input_binding_info], [im])
+
+ # Run inference
+ print("Running inference({0}) ...".format(idx))
+ runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+
+ # Process output
+ out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+ results = np.argsort(out_tensor)[::-1]
+ print_top_n(5, results, labels, out_tensor)
+
+
+def unzip_file(filename: str):
+ """Unzips a file.
+
+ Args:
+ filename(str): Name of the file
+
+ Returns:
+ None
+ """
+ with ZipFile(filename, 'r') as zip_obj:
+ zip_obj.extractall()
+
def parse_command_line(desc: str = ""):
"""Adds arguments to the script.
Args:
- desc(str): Script description.
+ desc (str): Script description
Returns:
- Namespace: Arguments to the script command.
+ Namespace: Arguments to the script command
"""
parser = argparse.ArgumentParser(description=desc)
parser.add_argument("-v", "--verbose", help="Increase output verbosity",
action="store_true")
+ parser.add_argument("-d", "--data-dir", help="Data directory which contains all the images.",
+ action="store", default="")
+ parser.add_argument("-m", "--model-dir",
+ help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).", action="store",
+ default="")
return parser.parse_args()
@@ -30,15 +80,14 @@ def __create_network(model_file: str, backends: list, parser=None):
"""Creates a network based on a file and parser type.
Args:
- model_file (str): Path of the model file.
+ model_file (str): Path of the model file
backends (list): List of backends to use when running inference.
parser_type: Parser instance. (pyarmnn.ITFliteParser/pyarmnn.IOnnxParser...)
Returns:
- int: Network ID.
- int: Graph ID.
- IParser: TF Lite parser instance.
- IRuntime: Runtime object instance.
+ int: Network ID
+ IParser: TF Lite parser instance
+ IRuntime: Runtime object instance
"""
args = parse_command_line()
options = ann.CreationOptions()
@@ -189,7 +238,7 @@ def print_top_n(N: int, results: list, labels: list, prob: list):
print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))
-def download_file(url: str, force: bool = False, filename: str = None, dest: str = "tmp"):
+def download_file(url: str, force: bool = False, filename: str = None):
"""Downloads a file.
Args:
@@ -197,25 +246,113 @@ def download_file(url: str, force: bool = False, filename: str = None, dest: str
force (bool): Forces to download the file even if it exists.
filename (str): Renames the file when set.
+ Raises:
+ RuntimeError: If for some reason download fails.
+
Returns:
str: Path to the downloaded file.
"""
- if filename is None: # extract filename from url when None
- filename = urlparse(url)
- filename = os.path.basename(filename.path)
-
- if str is not None:
- if not os.path.exists(dest):
- os.makedirs(dest)
- filename = os.path.join(dest, filename)
-
- print("Downloading '{0}' from '{1}' ...".format(filename, url))
- if not os.path.exists(filename) or force is True:
- r = requests.get(url)
- with open(filename, 'wb') as f:
- f.write(r.content)
- print("Finished.")
- else:
- print("File already exists.")
+ try:
+ if filename is None: # extract filename from url when None
+ filename = urlparse(url)
+ filename = os.path.basename(filename.path)
+
+ print("Downloading '{0}' from '{1}' ...".format(filename, url))
+ if not os.path.exists(filename) or force is True:
+ r = requests.get(url)
+ with open(filename, 'wb') as f:
+ f.write(r.content)
+ print("Finished.")
+ else:
+ print("File already exists.")
+ except:
+ raise RuntimeError("Unable to download file.")
return filename
+
+
+def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
+ """Gets model and labels.
+
+ Args:
+ model_dir(str): Folder in which model and label files can be found
+ model (str): Name of the model file
+ labels (str): Name of the labels file
+ archive (str): Name of the archive file (optional - need to provide only labels and model)
+ download_url(str or list): Archive url or urls if multiple files (optional - to to provide only to download it)
+
+ Returns:
+ tuple (str, str): Output label and model filenames
+ """
+ labels = os.path.join(model_dir, labels)
+ model = os.path.join(model_dir, model)
+
+ if os.path.exists(labels) and os.path.exists(model):
+ print("Found model ({0}) and labels ({1}).".format(model, labels))
+ elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
+ print("Found archive ({0}). Unzipping ...".format(archive))
+ unzip_file(archive)
+ elif download_url is not None:
+ print("Model, labels or archive not found. Downloading ...".format(archive))
+ try:
+ if isinstance(download_url, str):
+ download_url = [download_url]
+ for dl in download_url:
+ archive = download_file(dl)
+ if dl.lower().endswith(".zip"):
+ unzip_file(archive)
+ except RuntimeError:
+ print("Unable to download file ({}).".format(archive_url))
+
+ if not os.path.exists(labels) or not os.path.exists(model):
+ raise RuntimeError("Unable to provide model and labels.")
+
+ return model, labels
+
+
+def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+ """Lists files of a certain format in a folder.
+
+ Args:
+ folder (str): Path to the folder to search
+ formats (list): List of supported files
+
+ Returns:
+ list: A list of found files
+ """
+ files = []
+ if folder and not os.path.exists(folder):
+ print("Folder '{}' does not exist.".format(folder))
+ return files
+
+ for file in os.listdir(folder if folder else os.getcwd()):
+ for frmt in formats:
+ if file.lower().endswith(frmt):
+ files.append(os.path.join(folder, file) if folder else file)
+ break # only the format loop
+
+ return files
+
+
+def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
+ """Gets image.
+
+ Args:
+ image (str): Image filename
+ image_url (str): Image url
+
+ Returns:
+ str: Output image filename
+ """
+ images = list_images(image_dir)
+ if not images and image_url is not None:
+ print("No images found. Downloading ...")
+ try:
+ images = [download_file(image_url)]
+ except RuntimeError:
+ print("Unable to download file ({0}).".format(image_url))
+
+ if not images:
+ raise RuntimeError("Unable to provide images.")
+
+ return images