diff options
Diffstat (limited to 'samples/ImageClassification/run_classifier.py')
-rw-r--r-- | samples/ImageClassification/run_classifier.py | 237 |
1 files changed, 237 insertions, 0 deletions
diff --git a/samples/ImageClassification/run_classifier.py b/samples/ImageClassification/run_classifier.py new file mode 100644 index 0000000000..1b4b9ed61e --- /dev/null +++ b/samples/ImageClassification/run_classifier.py @@ -0,0 +1,237 @@ +import argparse +from pathlib import Path +from typing import Union + +import tflite_runtime.interpreter as tflite +from PIL import Image +import numpy as np + + +def check_args(args: argparse.Namespace): + """Check the values used in the command-line have acceptable values + + args: + - args: argparse.Namespace + + returns: + - None + + raises: + - FileNotFoundError: if passed files do not exist. + - IOError: if files are of incorrect format. + """ + input_image_p = args.input_image + if not input_image_p.suffix in (".png", ".jpg", ".jpeg"): + raise IOError( + "--input_image option should point to an image file of the " + "format .jpg, .jpeg, .png" + ) + if not input_image_p.exists(): + raise FileNotFoundError("Cannot find ", input_image_p.name) + model_p = args.model_file + if not model_p.suffix == ".tflite": + raise IOError("--model_file should point to a tflite file.") + if not model_p.exists(): + raise FileNotFoundError("Cannot find ", model_p.name) + label_mapping_p = args.label_file + if not label_mapping_p.suffix == ".txt": + raise IOError("--label_file expects a .txt file.") + if not label_mapping_p.exists(): + raise FileNotFoundError("Cannot find ", label_mapping_p.name) + + # check all args given in preferred backends make sense + supported_backends = ["GpuAcc", "CpuAcc", "CpuRef"] + if not all([backend in supported_backends for backend in args.preferred_backends]): + raise ValueError("Incorrect backends given. Please choose from "\ + "'GpuAcc', 'CpuAcc', 'CpuRef'.") + + return None + + +def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool): + """load an image and put into correct format for the tensorflow lite model + + args: + - image_path: pathlib.Path + - model_input_dims: tuple (or array-like). (height,width) + + returns: + - image: np.array + """ + height, width = model_input_dims + # load and resize image + image = Image.open(image_path).resize((width, height)) + # convert to greyscale if expected + if grayscale: + image = image.convert("LA") + + image = np.expand_dims(image, axis=0) + + return image + + +def load_delegate(delegate_path: Path, backends: list): + """load the armnn delegate. + + args: + - delegate_path: pathlib.Path -> location of you libarmnnDelegate.so + - backends: list -> list of backends you want to use in string format + + returns: + - armnn_delegate: tflite.delegate + """ + # create a command separated string + backend_string = ",".join(backends) + # load delegate + armnn_delegate = tflite.load_delegate( + library=delegate_path, + options={"backends": backend_string, "logging-severity": "info"}, + ) + + return armnn_delegate + + +def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate): + """load a tflite model for use with the armnn delegate. + + args: + - model_path: pathlib.Path + - armnn_delegate: tflite.TfLiteDelegate + + returns: + - interpreter: tflite.Interpreter + """ + interpreter = tflite.Interpreter( + model_path=model_path.as_posix(), experimental_delegates=[armnn_delegate] + ) + interpreter.allocate_tensors() + + return interpreter + + +def run_inference(interpreter, input_image): + """Run inference on a processed input image and return the output from + inference. + + args: + - interpreter: tflite_runtime.interpreter.Interpreter + - input_image: np.array + + returns: + - output_data: np.array + """ + # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + # Test model on random input data. + interpreter.set_tensor(input_details[0]["index"], input_image) + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]["index"]) + + return output_data + + +def create_mapping(label_mapping_p): + """Creates a Python dictionary mapping an index to a label. + + label_mapping[idx] = label + + args: + - label_mapping_p: pathlib.Path + + returns: + - label_mapping: dict + """ + idx = 0 + label_mapping = {} + with open(label_mapping_p) as label_mapping_raw: + for line in label_mapping_raw: + label_mapping[idx] = line + idx = 1 + + return label_mapping + + +def process_output(output_data, label_mapping): + """Process the output tensor into a label from the labelmapping file. Takes + the index of the maximum valur from the output array. + + args: + - output_data: np.array + - label_mapping: dict + + returns: + - str: labelmapping for max index. + """ + idx = np.argmax(output_data[0]) + + return label_mapping[idx] + + +def main(args): + """Run the inference for options passed in the command line. + + args: + - args: argparse.Namespace + + returns: + - None + """ + # sanity check on args + check_args(args) + # load in the armnn delegate + armnn_delegate = load_delegate(args.delegate_path, args.preferred_backends) + # load tflite model + interpreter = load_tf_model(args.model_file, armnn_delegate) + # get input shape for image resizing + input_shape = interpreter.get_input_details()[0]["shape"] + height, width = input_shape[1], input_shape[2] + input_shape = (height, width) + # load input image + input_image = load_image(args.input_image, input_shape, False) + # get label mapping + labelmapping = create_mapping(args.label_file) + output_tensor = run_inference(interpreter, input_image) + output_prediction = process_output(output_tensor, labelmapping) + + print("Prediction: ", output_prediction) + + return None + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--input_image", help="File path of image file", type=Path, required=True + ) + parser.add_argument( + "--model_file", + help="File path of the model tflite file", + type=Path, + required=True, + ) + parser.add_argument( + "--label_file", + help="File path of model labelmapping file", + type=Path, + required=True, + ) + parser.add_argument( + "--delegate_path", + help="File path of ArmNN delegate file", + type=Path, + required=True, + ) + parser.add_argument( + "--preferred_backends", + help="list of backends in order of preference", + type=str, + nargs=" ", + required=False, + default=["CpuAcc", "CpuRef"], + ) + args = parser.parse_args() + + main(args) |