diff options
author | Éanna Ó Catháin <eanna.ocathain@arm.com> | 2020-11-16 14:12:11 +0000 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-11-17 12:23:56 +0000 |
commit | 145c88f851d12d2cadc2f080d232c1d5963d6e47 (patch) | |
tree | 6ae197d74782cd2c7ef8965f4b36acabc65ce453 /python/pyarmnn/examples/common/utils.py | |
parent | aa41d5d2f43790938f3a32586626be5ef55b6ca9 (diff) | |
download | armnn-145c88f851d12d2cadc2f080d232c1d5963d6e47.tar.gz |
MLECO-1253 Adding ASR sample application using the PyArmNN api
Change-Id: I450b23800ca316a5bfd4608c8559cf4f11271c21
Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Diffstat (limited to 'python/pyarmnn/examples/common/utils.py')
-rw-r--r-- | python/pyarmnn/examples/common/utils.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py new file mode 100644 index 0000000000..cf09fdefb8 --- /dev/null +++ b/python/pyarmnn/examples/common/utils.py @@ -0,0 +1,41 @@ +# Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Contains helper functions that can be used across the example apps.""" + +import os +import errno +from pathlib import Path + +import numpy as np + + +def dict_labels(labels_file_path: str, include_rgb=False) -> dict: + """Creates a dictionary of labels from the input labels file. + + Args: + labels_file: Path to file containing labels to map model outputs. + include_rgb: Adds randomly generated RGB values to the values of the + dictionary. Used for plotting bounding boxes of different colours. + + Returns: + Dictionary with classification indices for keys and labels for values. + + Raises: + FileNotFoundError: + Provided `labels_file_path` does not exist. + """ + labels_file = Path(labels_file_path) + if not labels_file.is_file(): + raise FileNotFoundError( + errno.ENOENT, os.strerror(errno.ENOENT), labels_file_path + ) + + labels = {} + with open(labels_file, "r") as f: + for idx, line in enumerate(f, 0): + if include_rgb: + labels[idx] = line.strip("\n"), tuple(np.random.random(size=3) * 255) + else: + labels[idx] = line.strip("\n") + return labels |