diff options
Diffstat (limited to 'python/pyarmnn/examples/common/utils.py')
-rw-r--r-- | python/pyarmnn/examples/common/utils.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py new file mode 100644 index 0000000000..cf09fdefb8 --- /dev/null +++ b/python/pyarmnn/examples/common/utils.py @@ -0,0 +1,41 @@ +# Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Contains helper functions that can be used across the example apps.""" + +import os +import errno +from pathlib import Path + +import numpy as np + + +def dict_labels(labels_file_path: str, include_rgb=False) -> dict: + """Creates a dictionary of labels from the input labels file. + + Args: + labels_file: Path to file containing labels to map model outputs. + include_rgb: Adds randomly generated RGB values to the values of the + dictionary. Used for plotting bounding boxes of different colours. + + Returns: + Dictionary with classification indices for keys and labels for values. + + Raises: + FileNotFoundError: + Provided `labels_file_path` does not exist. + """ + labels_file = Path(labels_file_path) + if not labels_file.is_file(): + raise FileNotFoundError( + errno.ENOENT, os.strerror(errno.ENOENT), labels_file_path + ) + + labels = {} + with open(labels_file, "r") as f: + for idx, line in enumerate(f, 0): + if include_rgb: + labels[idx] = line.strip("\n"), tuple(np.random.random(size=3) * 255) + else: + labels[idx] = line.strip("\n") + return labels |