diff options
Diffstat (limited to 'python/pyarmnn/examples/example_utils.py')
-rw-r--r-- | python/pyarmnn/examples/example_utils.py | 189 |
1 files changed, 163 insertions, 26 deletions
diff --git a/python/pyarmnn/examples/example_utils.py b/python/pyarmnn/examples/example_utils.py index 5ef30f2331..e5425dde52 100644 --- a/python/pyarmnn/examples/example_utils.py +++ b/python/pyarmnn/examples/example_utils.py @@ -2,27 +2,77 @@ # SPDX-License-Identifier: MIT from urllib.parse import urlparse -import os from PIL import Image +from zipfile import ZipFile +import os import pyarmnn as ann import numpy as np import requests import argparse import warnings +DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg' + + +def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info): + """Runs inference on a set of images. + + Args: + runtime: Arm NN runtime + net_id: Network ID + images: Loaded images to run inference on + labels: Loaded labels per class + input_binding_info: Network input information + output_binding_info: Network output information + + Returns: + None + """ + output_tensors = ann.make_output_tensors([output_binding_info]) + for idx, im in enumerate(images): + # Create input tensors + input_tensors = ann.make_input_tensors([input_binding_info], [im]) + + # Run inference + print("Running inference({0}) ...".format(idx)) + runtime.EnqueueWorkload(net_id, input_tensors, output_tensors) + + # Process output + out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0] + results = np.argsort(out_tensor)[::-1] + print_top_n(5, results, labels, out_tensor) + + +def unzip_file(filename: str): + """Unzips a file. + + Args: + filename(str): Name of the file + + Returns: + None + """ + with ZipFile(filename, 'r') as zip_obj: + zip_obj.extractall() + def parse_command_line(desc: str = ""): """Adds arguments to the script. Args: - desc(str): Script description. + desc (str): Script description Returns: - Namespace: Arguments to the script command. + Namespace: Arguments to the script command """ parser = argparse.ArgumentParser(description=desc) parser.add_argument("-v", "--verbose", help="Increase output verbosity", action="store_true") + parser.add_argument("-d", "--data-dir", help="Data directory which contains all the images.", + action="store", default="") + parser.add_argument("-m", "--model-dir", + help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).", action="store", + default="") return parser.parse_args() @@ -30,15 +80,14 @@ def __create_network(model_file: str, backends: list, parser=None): """Creates a network based on a file and parser type. Args: - model_file (str): Path of the model file. + model_file (str): Path of the model file backends (list): List of backends to use when running inference. parser_type: Parser instance. (pyarmnn.ITFliteParser/pyarmnn.IOnnxParser...) Returns: - int: Network ID. - int: Graph ID. - IParser: TF Lite parser instance. - IRuntime: Runtime object instance. + int: Network ID + IParser: TF Lite parser instance + IRuntime: Runtime object instance """ args = parse_command_line() options = ann.CreationOptions() @@ -189,7 +238,7 @@ def print_top_n(N: int, results: list, labels: list, prob: list): print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]])) -def download_file(url: str, force: bool = False, filename: str = None, dest: str = "tmp"): +def download_file(url: str, force: bool = False, filename: str = None): """Downloads a file. Args: @@ -197,25 +246,113 @@ def download_file(url: str, force: bool = False, filename: str = None, dest: str force (bool): Forces to download the file even if it exists. filename (str): Renames the file when set. + Raises: + RuntimeError: If for some reason download fails. + Returns: str: Path to the downloaded file. """ - if filename is None: # extract filename from url when None - filename = urlparse(url) - filename = os.path.basename(filename.path) - - if str is not None: - if not os.path.exists(dest): - os.makedirs(dest) - filename = os.path.join(dest, filename) - - print("Downloading '{0}' from '{1}' ...".format(filename, url)) - if not os.path.exists(filename) or force is True: - r = requests.get(url) - with open(filename, 'wb') as f: - f.write(r.content) - print("Finished.") - else: - print("File already exists.") + try: + if filename is None: # extract filename from url when None + filename = urlparse(url) + filename = os.path.basename(filename.path) + + print("Downloading '{0}' from '{1}' ...".format(filename, url)) + if not os.path.exists(filename) or force is True: + r = requests.get(url) + with open(filename, 'wb') as f: + f.write(r.content) + print("Finished.") + else: + print("File already exists.") + except: + raise RuntimeError("Unable to download file.") return filename + + +def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None): + """Gets model and labels. + + Args: + model_dir(str): Folder in which model and label files can be found + model (str): Name of the model file + labels (str): Name of the labels file + archive (str): Name of the archive file (optional - need to provide only labels and model) + download_url(str or list): Archive url or urls if multiple files (optional - to to provide only to download it) + + Returns: + tuple (str, str): Output label and model filenames + """ + labels = os.path.join(model_dir, labels) + model = os.path.join(model_dir, model) + + if os.path.exists(labels) and os.path.exists(model): + print("Found model ({0}) and labels ({1}).".format(model, labels)) + elif archive is not None and os.path.exists(os.path.join(model_dir, archive)): + print("Found archive ({0}). Unzipping ...".format(archive)) + unzip_file(archive) + elif download_url is not None: + print("Model, labels or archive not found. Downloading ...".format(archive)) + try: + if isinstance(download_url, str): + download_url = [download_url] + for dl in download_url: + archive = download_file(dl) + if dl.lower().endswith(".zip"): + unzip_file(archive) + except RuntimeError: + print("Unable to download file ({}).".format(archive_url)) + + if not os.path.exists(labels) or not os.path.exists(model): + raise RuntimeError("Unable to provide model and labels.") + + return model, labels + + +def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']): + """Lists files of a certain format in a folder. + + Args: + folder (str): Path to the folder to search + formats (list): List of supported files + + Returns: + list: A list of found files + """ + files = [] + if folder and not os.path.exists(folder): + print("Folder '{}' does not exist.".format(folder)) + return files + + for file in os.listdir(folder if folder else os.getcwd()): + for frmt in formats: + if file.lower().endswith(frmt): + files.append(os.path.join(folder, file) if folder else file) + break # only the format loop + + return files + + +def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL): + """Gets image. + + Args: + image (str): Image filename + image_url (str): Image url + + Returns: + str: Output image filename + """ + images = list_images(image_dir) + if not images and image_url is not None: + print("No images found. Downloading ...") + try: + images = [download_file(image_url)] + except RuntimeError: + print("Unable to download file ({0}).".format(image_url)) + + if not images: + raise RuntimeError("Unable to provide images.") + + return images |