aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHenri Woodcock <henri.woodcock@arm.com>2021-05-19 13:41:44 +0100
committerKevin May <kevin.may@arm.com>2021-05-20 15:39:00 +0000
commit3b38eedb3cc8f1c95a9ce62ddfbe926708666e72 (patch)
tree4cd94af00d6c3562e104e4065431bc9954d7a0f0
parent5fc0fd6661f9647092deb052d052973a237bd52d (diff)
downloadarmnn-3b38eedb3cc8f1c95a9ce62ddfbe926708666e72.tar.gz
CSAF-235 Arm NN Delegate Image Classificaton Example.
* To be used in developer.arm.com Image Classification with guide. Signed-off-by: Henri Woodcock henri.woodcock@arm.com Change-Id: I3dd3b3b7ca3e579be9fd70900cff85c78f3da3f7
-rw-r--r--samples/ImageClassification/README.md135
-rw-r--r--samples/ImageClassification/requirements.txt3
-rw-r--r--samples/ImageClassification/run_classifier.py237
3 files changed, 375 insertions, 0 deletions
diff --git a/samples/ImageClassification/README.md b/samples/ImageClassification/README.md
new file mode 100644
index 0000000000..068d0c916a
--- /dev/null
+++ b/samples/ImageClassification/README.md
@@ -0,0 +1,135 @@
+# Image Classification with the Arm NN Tensorflow Lite Delegate
+
+This application demonstrates the use of the Arm NN Tensorflow Lite Delegate.
+In this application we integrate the Arm NN Tensorflow Lite Delegate into the
+TensorFlow Lite Python package.
+
+## Before You Begin
+
+This repository assumes you have built, or have downloaded the
+`libarmnnDelegate.so` and `libarmnn.so` from the GitHub releases page. You will
+also need to have built the TensorFlow Lite library from source.
+
+If you have not already installed these, please follow our guides in the ArmNN
+repository. The guide to build the delegate can be found
+[here](../../delegate/BuildGuideNative.md) and the guide to integrate the
+delegate into Python can be found
+[here](../../delegate/IntegrateDelegateIntoPython.md).
+
+
+## Getting Started
+
+Before running the application, we will first need to:
+
+- Install the required Python packages
+- Download this example
+- Download a model and corresponding label mapping
+- Download an example image
+
+1. Install required packages and Git Large File Storage (to download models
+from the Arm ML-Zoo).
+
+ ```bash
+ sudo apt-get install -y python3 python3-pip wget git git-lfs unzip
+ git lfs install
+ ```
+
+2. Clone the Arm NN repository and change directory to this example.
+
+ ```bash
+ git clone https://github.com/arm-software/armnn.git
+ cd armnn/samples/ImageClassification
+ ```
+
+3. Download your model and label mappings.
+
+ For this example we use the `MobileNetV2` model. This model can be found in
+ the Arm ML-Zoo as well as scripts to download the labels for the model.
+
+ ```bash
+ export BASEDIR=$(pwd)
+ #clone the model zoo
+ git clone https://github.com/arm-software/ml-zoo.git
+ #go to the mobilenetv2 uint8 folder
+ cd ml-zoo/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8
+ #generate the labelmapping
+ ./get_class_labels.sh
+ #cd back to this project folder
+ cd BASEDIR
+ #copy your model and label mapping
+ cp ml-zoo/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8/mobilenet_v2_1.0_224_quantized_1_default_1.tflite .
+ cp ml-zoo/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8 labelmappings.txt .
+ ```
+
+4. Download a test image.
+
+ ```bash
+ wget -O cat.png "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
+ ```
+
+5. Download the required Python packages.
+
+ ```bash
+ pip3 install -r requirements.txt
+ ```
+
+6. Copy over your `libtensorflow_lite_all.so` and `libarmnn.so` library files
+you built/downloaded before trying this application to the application
+folder. For example:
+
+ ```bash
+ cp path/to/tensorflow/directory/tensorflow/bazel-bin/libtensorflow_lite_all.so .
+ cp /path/to/armnn/binaries/libarmnn.so .
+ ```
+
+## Folder Structure
+
+You should now have the following folder structure:
+
+```
+.
+├── README.md
+├── run_classifier.py # script for the demo
+├── libtensorflow_lite_all.so # tflite library built from tensorflow
+├── libarmnn.so
+├── cat.png # downloaded example image
+├── mobilenet_v2_1.0_224_quantized_1_default_1.tflite #tflite model from ml-zoo
+└── labelmappings.txt # model labelmappings for output processing
+```
+
+## Run the model
+
+```bash
+python3 run_classifier.py \
+--input_image cat.png \
+--model_file mobilenet_v2_1.0_224_quantized_1_default_1.tflite \
+--label_file labelmappings.txt \
+--delegate_path /path/to/delegate/libarmnnDelegate.so.24 \
+--preferred_backends GpuAcc CpuAcc CpuRef
+```
+
+The output prediction will be printed. In this example we get:
+
+```bash
+'tabby, tabby cat'
+```
+
+## Running an inference with the Arm NN TensorFlow Lite Delegate
+
+Compared to your usual TensorFlow Lite projects, using the Arm NN TensorFlow
+Lite Delegate requires one extra step when loading in your model:
+
+```python
+import tflite_runtime.interpreter as tflite
+
+armnn_delegate = tflite.load_delegate("/path/to/delegate/libarmnnDelegate.so",
+ options={
+ "backends": "GpuAcc,CpuAcc,CpuRef",
+ "logging-severity": "info"
+ }
+)
+interpreter = tflite.Interpreter(
+ model_path="mobilenet_v2_1.0_224_quantized_1_default_1.tflite",
+ experimental_delegates=[armnn_delegate]
+)
+```
diff --git a/samples/ImageClassification/requirements.txt b/samples/ImageClassification/requirements.txt
new file mode 100644
index 0000000000..3f100b29b1
--- /dev/null
+++ b/samples/ImageClassification/requirements.txt
@@ -0,0 +1,3 @@
+numpy==1.20.2
+Pillow==8.2.0
+pybind11==2.6.2
diff --git a/samples/ImageClassification/run_classifier.py b/samples/ImageClassification/run_classifier.py
new file mode 100644
index 0000000000..1b4b9ed61e
--- /dev/null
+++ b/samples/ImageClassification/run_classifier.py
@@ -0,0 +1,237 @@
+import argparse
+from pathlib import Path
+from typing import Union
+
+import tflite_runtime.interpreter as tflite
+from PIL import Image
+import numpy as np
+
+
+def check_args(args: argparse.Namespace):
+ """Check the values used in the command-line have acceptable values
+
+ args:
+ - args: argparse.Namespace
+
+ returns:
+ - None
+
+ raises:
+ - FileNotFoundError: if passed files do not exist.
+ - IOError: if files are of incorrect format.
+ """
+ input_image_p = args.input_image
+ if not input_image_p.suffix in (".png", ".jpg", ".jpeg"):
+ raise IOError(
+ "--input_image option should point to an image file of the "
+ "format .jpg, .jpeg, .png"
+ )
+ if not input_image_p.exists():
+ raise FileNotFoundError("Cannot find ", input_image_p.name)
+ model_p = args.model_file
+ if not model_p.suffix == ".tflite":
+ raise IOError("--model_file should point to a tflite file.")
+ if not model_p.exists():
+ raise FileNotFoundError("Cannot find ", model_p.name)
+ label_mapping_p = args.label_file
+ if not label_mapping_p.suffix == ".txt":
+ raise IOError("--label_file expects a .txt file.")
+ if not label_mapping_p.exists():
+ raise FileNotFoundError("Cannot find ", label_mapping_p.name)
+
+ # check all args given in preferred backends make sense
+ supported_backends = ["GpuAcc", "CpuAcc", "CpuRef"]
+ if not all([backend in supported_backends for backend in args.preferred_backends]):
+ raise ValueError("Incorrect backends given. Please choose from "\
+ "'GpuAcc', 'CpuAcc', 'CpuRef'.")
+
+ return None
+
+
+def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
+ """load an image and put into correct format for the tensorflow lite model
+
+ args:
+ - image_path: pathlib.Path
+ - model_input_dims: tuple (or array-like). (height,width)
+
+ returns:
+ - image: np.array
+ """
+ height, width = model_input_dims
+ # load and resize image
+ image = Image.open(image_path).resize((width, height))
+ # convert to greyscale if expected
+ if grayscale:
+ image = image.convert("LA")
+
+ image = np.expand_dims(image, axis=0)
+
+ return image
+
+
+def load_delegate(delegate_path: Path, backends: list):
+ """load the armnn delegate.
+
+ args:
+ - delegate_path: pathlib.Path -> location of you libarmnnDelegate.so
+ - backends: list -> list of backends you want to use in string format
+
+ returns:
+ - armnn_delegate: tflite.delegate
+ """
+ # create a command separated string
+ backend_string = ",".join(backends)
+ # load delegate
+ armnn_delegate = tflite.load_delegate(
+ library=delegate_path,
+ options={"backends": backend_string, "logging-severity": "info"},
+ )
+
+ return armnn_delegate
+
+
+def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
+ """load a tflite model for use with the armnn delegate.
+
+ args:
+ - model_path: pathlib.Path
+ - armnn_delegate: tflite.TfLiteDelegate
+
+ returns:
+ - interpreter: tflite.Interpreter
+ """
+ interpreter = tflite.Interpreter(
+ model_path=model_path.as_posix(), experimental_delegates=[armnn_delegate]
+ )
+ interpreter.allocate_tensors()
+
+ return interpreter
+
+
+def run_inference(interpreter, input_image):
+ """Run inference on a processed input image and return the output from
+ inference.
+
+ args:
+ - interpreter: tflite_runtime.interpreter.Interpreter
+ - input_image: np.array
+
+ returns:
+ - output_data: np.array
+ """
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+ # Test model on random input data.
+ interpreter.set_tensor(input_details[0]["index"], input_image)
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]["index"])
+
+ return output_data
+
+
+def create_mapping(label_mapping_p):
+ """Creates a Python dictionary mapping an index to a label.
+
+ label_mapping[idx] = label
+
+ args:
+ - label_mapping_p: pathlib.Path
+
+ returns:
+ - label_mapping: dict
+ """
+ idx = 0
+ label_mapping = {}
+ with open(label_mapping_p) as label_mapping_raw:
+ for line in label_mapping_raw:
+ label_mapping[idx] = line
+ idx = 1
+
+ return label_mapping
+
+
+def process_output(output_data, label_mapping):
+ """Process the output tensor into a label from the labelmapping file. Takes
+ the index of the maximum valur from the output array.
+
+ args:
+ - output_data: np.array
+ - label_mapping: dict
+
+ returns:
+ - str: labelmapping for max index.
+ """
+ idx = np.argmax(output_data[0])
+
+ return label_mapping[idx]
+
+
+def main(args):
+ """Run the inference for options passed in the command line.
+
+ args:
+ - args: argparse.Namespace
+
+ returns:
+ - None
+ """
+ # sanity check on args
+ check_args(args)
+ # load in the armnn delegate
+ armnn_delegate = load_delegate(args.delegate_path, args.preferred_backends)
+ # load tflite model
+ interpreter = load_tf_model(args.model_file, armnn_delegate)
+ # get input shape for image resizing
+ input_shape = interpreter.get_input_details()[0]["shape"]
+ height, width = input_shape[1], input_shape[2]
+ input_shape = (height, width)
+ # load input image
+ input_image = load_image(args.input_image, input_shape, False)
+ # get label mapping
+ labelmapping = create_mapping(args.label_file)
+ output_tensor = run_inference(interpreter, input_image)
+ output_prediction = process_output(output_tensor, labelmapping)
+
+ print("Prediction: ", output_prediction)
+
+ return None
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "--input_image", help="File path of image file", type=Path, required=True
+ )
+ parser.add_argument(
+ "--model_file",
+ help="File path of the model tflite file",
+ type=Path,
+ required=True,
+ )
+ parser.add_argument(
+ "--label_file",
+ help="File path of model labelmapping file",
+ type=Path,
+ required=True,
+ )
+ parser.add_argument(
+ "--delegate_path",
+ help="File path of ArmNN delegate file",
+ type=Path,
+ required=True,
+ )
+ parser.add_argument(
+ "--preferred_backends",
+ help="list of backends in order of preference",
+ type=str,
+ nargs=" ",
+ required=False,
+ default=["CpuAcc", "CpuRef"],
+ )
+ args = parser.parse_args()
+
+ main(args)