author     Éanna Ó Catháin <eanna.ocathain@arm.com>    2020-11-16 14:12:11 +0000
committer  Jim Flynn <jim.flynn@arm.com>               2020-11-17 12:23:56 +0000
commit     145c88f851d12d2cadc2f080d232c1d5963d6e47 (patch)
tree       6ae197d74782cd2c7ef8965f4b36acabc65ce453
parent     aa41d5d2f43790938f3a32586626be5ef55b6ca9 (diff)
download   armnn-145c88f851d12d2cadc2f080d232c1d5963d6e47.tar.gz
MLECO-1253 Adding ASR sample application using the PyArmNN api
Change-Id: I450b23800ca316a5bfd4608c8559cf4f11271c21
Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
-rw-r--r--  python/pyarmnn/examples/common/cv_utils.py (renamed from python/pyarmnn/examples/object_detection/utils.py) | 415
-rw-r--r--  python/pyarmnn/examples/common/network_executor.py | 108
-rw-r--r--  python/pyarmnn/examples/common/tests/conftest.py | 40
-rw-r--r--  python/pyarmnn/examples/common/tests/context.py | 7
-rw-r--r--  python/pyarmnn/examples/common/tests/test_network_executor.py | 24
-rw-r--r--  python/pyarmnn/examples/common/tests/test_utils.py | 19
-rw-r--r--  python/pyarmnn/examples/common/utils.py | 41
-rw-r--r--  python/pyarmnn/examples/image_classification/example_utils.py | 20
-rw-r--r--  python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py | 91
-rw-r--r--  python/pyarmnn/examples/image_classification/requirements.txt | 2
-rw-r--r--  python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py | 71
-rw-r--r--  python/pyarmnn/examples/object_detection/requirements.txt | 5
-rw-r--r--  python/pyarmnn/examples/object_detection/run_video_file.py | 87
-rw-r--r--  python/pyarmnn/examples/object_detection/run_video_stream.py | 79
-rw-r--r--  python/pyarmnn/examples/speech_recognition/README.md | 158
-rw-r--r--  python/pyarmnn/examples/speech_recognition/__init__.py | 0
-rw-r--r--  python/pyarmnn/examples/speech_recognition/audio_capture.py | 56
-rw-r--r--  python/pyarmnn/examples/speech_recognition/audio_utils.py | 128
-rw-r--r--  python/pyarmnn/examples/speech_recognition/preprocess.py | 260
-rw-r--r--  python/pyarmnn/examples/speech_recognition/requirements.txt | 2
-rw-r--r--  python/pyarmnn/examples/speech_recognition/run_audio_file.py | 94
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/conftest.py | 34
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/context.py | 13
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py | 17
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_decoder.py | 28
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py | 286
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy | bin 0 -> 4420 bytes
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt | 29
28 files changed, 1687 insertions, 427 deletions
diff --git a/python/pyarmnn/examples/object_detection/utils.py b/python/pyarmnn/examples/common/cv_utils.py
index 1235bf4fa6..61aa46c3d7 100644
--- a/python/pyarmnn/examples/object_detection/utils.py
+++ b/python/pyarmnn/examples/common/cv_utils.py
@@ -1,231 +1,184 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-"""
-This file contains shared functions used in the object detection scripts for
-preprocessing data, preparing the network and postprocessing.
-"""
-
-import os
-import cv2
-import numpy as np
-import pyarmnn as ann
-
-
-def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
- """
- Creates a video writer object to write processed frames to file.
-
- Args:
- video: Video capture object, contains information about data source.
- video_path: User-specified video file path.
- output_path: Optional path to save the processed video.
-
- Returns:
- Video writer object.
- """
- _, ext = os.path.splitext(video_path)
-
- if output_path is not None:
- assert os.path.isdir(output_path)
-
- i, filename = 0, os.path.join(output_path if output_path is not None else str(), f'object_detection_demo{ext}')
- while os.path.exists(filename):
- i += 1
- filename = os.path.join(output_path if output_path is not None else str(), f'object_detection_demo({i}){ext}')
-
- video_writer = cv2.VideoWriter(filename=filename,
- fourcc=cv2.VideoWriter_fourcc(*'mp4v'),
- fps=int(video.get(cv2.CAP_PROP_FPS)),
- frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
- int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
- return video_writer
-
-
-def create_network(model_file: str, backends: list):
- """
- Creates a network based on the model file and a list of backends.
-
- Args:
- model_file: User-specified model file.
- backends: List of backends to optimize network.
-
- Returns:
- net_id: Unique ID of the network to run.
- runtime: Runtime context for executing inference.
- input_binding_info: Contains essential information about the model input.
- output_binding_info: Used to map output tensor and its memory.
- """
- if not os.path.exists(model_file):
- raise FileNotFoundError(f'Model file not found for: {model_file}')
-
- # Determine which parser to create based on model file extension
- parser = None
- _, ext = os.path.splitext(model_file)
- if ext == '.tflite':
- parser = ann.ITfLiteParser()
- elif ext == '.pb':
- parser = ann.ITfParser()
- elif ext == '.onnx':
- parser = ann.IOnnxParser()
- assert (parser is not None)
- network = parser.CreateNetworkFromBinaryFile(model_file)
-
- # Specify backends to optimize network
- preferred_backends = []
- for b in backends:
- preferred_backends.append(ann.BackendId(b))
-
- # Select appropriate device context and optimize the network for that device
- options = ann.CreationOptions()
- runtime = ann.IRuntime(options)
- opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
- ann.OptimizerOptions())
- print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n'
- f'Optimization warnings: {messages}')
-
- # Load the optimized network onto the Runtime device
- net_id, _ = runtime.LoadNetwork(opt_network)
-
- # Get input and output binding information
- graph_id = parser.GetSubgraphCount() - 1
- input_names = parser.GetSubgraphInputTensorNames(graph_id)
- input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
- output_names = parser.GetSubgraphOutputTensorNames(graph_id)
- output_binding_info = []
- for output_name in output_names:
- outBindInfo = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
- output_binding_info.append(outBindInfo)
- return net_id, runtime, input_binding_info, output_binding_info
-
-
-def dict_labels(labels_file: str):
- """
- Creates a labels dictionary from the input labels file.
-
- Args:
- labels_file: Default or user-specified file containing the model output labels.
-
- Returns:
- A dictionary keyed on the classification index with values corresponding to
- labels and randomly generated RGB colors.
- """
- labels_dict = {}
- with open(labels_file, 'r') as labels:
- for index, line in enumerate(labels, 0):
- labels_dict[index] = line.strip('\n'), tuple(np.random.random(size=3) * 255)
- return labels_dict
-
-
-def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
- """
- Resizes frame while maintaining aspect ratio, padding any empty space.
-
- Args:
- frame: Captured frame.
- input_binding_info: Contains shape of model input layer.
-
- Returns:
- Frame resized to the size of model input layer.
- """
- aspect_ratio = frame.shape[1] / frame.shape[0]
- model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
-
- if aspect_ratio >= 1.0:
- new_height, new_width = int(model_width / aspect_ratio), model_width
- b_padding, r_padding = model_height - new_height, 0
- else:
- new_height, new_width = model_height, int(model_height * aspect_ratio)
- b_padding, r_padding = 0, model_width - new_width
-
- # Resize and pad any empty space
- frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
- frame = cv2.copyMakeBorder(frame, top=0, bottom=b_padding, left=0, right=r_padding,
- borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
- return frame
-
-
-def preprocess(frame: np.ndarray, input_binding_info: tuple):
- """
- Takes a frame, resizes, swaps channels and converts data type to match
- model input layer. The converted frame is wrapped in a const tensor
- and bound to the input tensor.
-
- Args:
- frame: Captured frame from video.
- input_binding_info: Contains shape and data type of model input layer.
-
- Returns:
- Input tensor.
- """
- # Swap channels and resize frame to model resolution
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- resized_frame = resize_with_aspect_ratio(frame, input_binding_info)
-
- # Expand dimensions and convert data type to match model input
- data_type = np.float32 if input_binding_info[1].GetDataType() == ann.DataType_Float32 else np.uint8
- resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)
- assert resized_frame.shape == tuple(input_binding_info[1].GetShape())
-
- input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
- return input_tensors
-
-
-def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> np.ndarray:
- """
- Executes inference for the loaded network.
-
- Args:
- input_tensors: The input frame tensor.
- output_tensors: The output tensor from output node.
- runtime: Runtime context for executing inference.
- net_id: Unique ID of the network to run.
-
- Returns:
- Inference results as a list of ndarrays.
- """
- runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
- output = ann.workload_tensors_to_ndarray(output_tensors)
- return output
-
-
-def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
- """
- Draws bounding boxes around detected objects and adds a label and confidence score.
-
- Args:
- frame: The original captured frame from video source.
- detections: A list of detected objects in the form [class, [box positions], confidence].
- resize_factor: Resizing factor to scale box coordinates to output frame size.
- labels: Dictionary of labels and colors keyed on the classification index.
- """
- for detection in detections:
- class_idx, box, confidence = [d for d in detection]
- label, color = labels[class_idx][0].capitalize(), labels[class_idx][1]
-
- # Obtain frame size and resized bounding box positions
- frame_height, frame_width = frame.shape[:2]
- x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]
-
- # Ensure box stays within the frame
- x_min, y_min = max(0, x_min), max(0, y_min)
- x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)
-
- # Draw bounding box around detected object
- cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)
-
- # Create label for detected object class
- label = f'{label} {confidence * 100:.1f}%'
- label_color = (0, 0, 0) if sum(color)>200 else (255, 255, 255)
-
- # Make sure label always stays on-screen
- x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]
-
- lbl_box_xy_min = (x_min, y_min if y_min<25 else y_min - y_text)
- lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min<25 else y_min)
- lbl_text_pos = (x_min + 5, y_min + 16 if y_min<25 else y_min - 5)
-
- # Add label and confidence value
- cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
- cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
- label_color, 1, cv2.LINE_AA)
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+This file contains helper functions for reading video/image data and
+pre/postprocessing of video/image data using OpenCV.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+import pyarmnn as ann
+
+
+def preprocess(frame: np.ndarray, input_binding_info: tuple):
+ """
+ Takes a frame, resizes, swaps channels and converts data type to match
+ model input layer. The converted frame is wrapped in a const tensor
+ and bound to the input tensor.
+
+ Args:
+ frame: Captured frame from video.
+ input_binding_info: Contains shape and data type of model input layer.
+
+ Returns:
+ Input tensor.
+ """
+ # Swap channels and resize frame to model resolution
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ resized_frame = resize_with_aspect_ratio(frame, input_binding_info)
+
+ # Expand dimensions and convert data type to match model input
+ data_type = np.float32 if input_binding_info[1].GetDataType() == ann.DataType_Float32 else np.uint8
+ resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)
+ assert resized_frame.shape == tuple(input_binding_info[1].GetShape())
+
+ input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
+ return input_tensors
+
+
+def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
+ """
+ Resizes frame while maintaining aspect ratio, padding any empty space.
+
+ Args:
+ frame: Captured frame.
+ input_binding_info: Contains shape of model input layer.
+
+ Returns:
+ Frame resized to the size of model input layer.
+ """
+ aspect_ratio = frame.shape[1] / frame.shape[0]
+ model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
+
+ if aspect_ratio >= 1.0:
+ new_height, new_width = int(model_width / aspect_ratio), model_width
+ b_padding, r_padding = model_height - new_height, 0
+ else:
+ new_height, new_width = model_height, int(model_height * aspect_ratio)
+ b_padding, r_padding = 0, model_width - new_width
+
+ # Resize and pad any empty space
+ frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
+ frame = cv2.copyMakeBorder(frame, top=0, bottom=b_padding, left=0, right=r_padding,
+ borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
+ return frame
+
+
+def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
+ """
+ Creates a video writer object to write processed frames to file.
+
+ Args:
+ video: Video capture object, contains information about data source.
+ video_path: User-specified video file path.
+ output_path: Optional path to save the processed video.
+
+ Returns:
+ Video writer object.
+ """
+ _, ext = os.path.splitext(video_path)
+
+ if output_path is not None:
+ assert os.path.isdir(output_path)
+
+ i, filename = 0, os.path.join(output_path if output_path is not None else str(), f'object_detection_demo{ext}')
+ while os.path.exists(filename):
+ i += 1
+ filename = os.path.join(output_path if output_path is not None else str(), f'object_detection_demo({i}){ext}')
+
+ video_writer = cv2.VideoWriter(filename=filename,
+ fourcc=get_source_encoding_int(video),
+ fps=int(video.get(cv2.CAP_PROP_FPS)),
+ frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
+ int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
+ return video_writer
+
+
+def init_video_file_capture(video_path: str, output_path: str):
+ """
+ Creates a video capture object from a video file.
+
+ Args:
+ video_path: User-specified video file path.
+ output_path: Optional path to save the processed video.
+
+ Returns:
+ Video capture object to capture frames, video writer object to write processed
+ frames to file, plus total frame count of video source to iterate through.
+ """
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f'Video file not found for: {video_path}')
+ video = cv2.VideoCapture(video_path)
+    if not video.isOpened():
+ raise RuntimeError(f'Failed to open video capture from file: {video_path}')
+
+ video_writer = create_video_writer(video, video_path, output_path)
+ iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
+ return video, video_writer, iter_frame_count
+
+
+def init_video_stream_capture(video_source: int):
+ """
+ Creates a video capture object from a device.
+
+ Args:
+ video_source: Device index used to read video stream.
+
+ Returns:
+ Video capture object used to capture frames from a video stream.
+ """
+ video = cv2.VideoCapture(video_source)
+    if not video.isOpened():
+ raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
+ print('Processing video stream. Press \'Esc\' key to exit the demo.')
+ return video
+
+
+def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
+ """
+ Draws bounding boxes around detected objects and adds a label and confidence score.
+
+ Args:
+ frame: The original captured frame from video source.
+ detections: A list of detected objects in the form [class, [box positions], confidence].
+ resize_factor: Resizing factor to scale box coordinates to output frame size.
+ labels: Dictionary of labels and colors keyed on the classification index.
+ """
+ for detection in detections:
+ class_idx, box, confidence = [d for d in detection]
+ label, color = labels[class_idx][0].capitalize(), labels[class_idx][1]
+
+ # Obtain frame size and resized bounding box positions
+ frame_height, frame_width = frame.shape[:2]
+ x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]
+
+ # Ensure box stays within the frame
+ x_min, y_min = max(0, x_min), max(0, y_min)
+ x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)
+
+ # Draw bounding box around detected object
+ cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)
+
+ # Create label for detected object class
+ label = f'{label} {confidence * 100:.1f}%'
+ label_color = (0, 0, 0) if sum(color)>200 else (255, 255, 255)
+
+ # Make sure label always stays on-screen
+ x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]
+
+ lbl_box_xy_min = (x_min, y_min if y_min<25 else y_min - y_text)
+ lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min<25 else y_min)
+ lbl_text_pos = (x_min + 5, y_min + 16 if y_min<25 else y_min - 5)
+
+ # Add label and confidence value
+ cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
+ cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
+ label_color, 1, cv2.LINE_AA)
+
+
+def get_source_encoding_int(video_capture):
+ return int(video_capture.get(cv2.CAP_PROP_FOURCC))
diff --git a/python/pyarmnn/examples/common/network_executor.py b/python/pyarmnn/examples/common/network_executor.py
new file mode 100644
index 0000000000..6e2c53c43d
--- /dev/null
+++ b/python/pyarmnn/examples/common/network_executor.py
@@ -0,0 +1,108 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+from typing import List, Tuple
+
+import pyarmnn as ann
+import numpy as np
+
+
+def create_network(model_file: str, backends: list, input_names: Tuple[str] = (), output_names: Tuple[str] = ()):
+ """
+ Creates a network based on the model file and a list of backends.
+
+ Args:
+ model_file: User-specified model file.
+ backends: List of backends to optimize network.
+        input_names: Optional tuple of input tensor names.
+        output_names: Optional tuple of output tensor names.
+
+ Returns:
+ net_id: Unique ID of the network to run.
+ runtime: Runtime context for executing inference.
+ input_binding_info: Contains essential information about the model input.
+ output_binding_info: Used to map output tensor and its memory.
+ """
+ if not os.path.exists(model_file):
+ raise FileNotFoundError(f'Model file not found for: {model_file}')
+
+ _, ext = os.path.splitext(model_file)
+ if ext == '.tflite':
+ parser = ann.ITfLiteParser()
+ else:
+ raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
+
+ network = parser.CreateNetworkFromBinaryFile(model_file)
+
+ # Specify backends to optimize network
+ preferred_backends = []
+ for b in backends:
+ preferred_backends.append(ann.BackendId(b))
+
+ # Select appropriate device context and optimize the network for that device
+ options = ann.CreationOptions()
+ runtime = ann.IRuntime(options)
+ opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
+ ann.OptimizerOptions())
+ print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n'
+ f'Optimization warnings: {messages}')
+
+ # Load the optimized network onto the Runtime device
+ net_id, _ = runtime.LoadNetwork(opt_network)
+
+ # Get input and output binding information
+ graph_id = parser.GetSubgraphCount() - 1
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+ output_binding_info = []
+ for output_name in output_names:
+ out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
+ output_binding_info.append(out_bind_info)
+ return net_id, runtime, input_binding_info, output_binding_info
+
+
+def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> List[np.ndarray]:
+ """
+ Executes inference for the loaded network.
+
+ Args:
+ input_tensors: The input frame tensor.
+ output_tensors: The output tensor from output node.
+ runtime: Runtime context for executing inference.
+ net_id: Unique ID of the network to run.
+
+ Returns:
+ list: Inference results as a list of ndarrays.
+ """
+ runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+ output = ann.workload_tensors_to_ndarray(output_tensors)
+ return output
+
+
+class ArmnnNetworkExecutor:
+
+ def __init__(self, model_file: str, backends: list):
+ """
+ Creates an inference executor for a given network and a list of backends.
+
+ Args:
+ model_file: User-specified model file.
+ backends: List of backends to optimize network.
+ """
+ self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = create_network(model_file,
+ backends)
+ self.output_tensors = ann.make_output_tensors(self.output_binding_info)
+
+ def run(self, input_tensors: list) -> List[np.ndarray]:
+ """
+ Executes inference for the loaded network.
+
+ Args:
+ input_tensors: The input frame tensor.
+
+ Returns:
+ list: Inference results as a list of ndarrays.
+ """
+ return execute_network(input_tensors, self.output_tensors, self.runtime, self.network_id)
diff --git a/python/pyarmnn/examples/common/tests/conftest.py b/python/pyarmnn/examples/common/tests/conftest.py
new file mode 100644
index 0000000000..5e027a0125
--- /dev/null
+++ b/python/pyarmnn/examples/common/tests/conftest.py
@@ -0,0 +1,40 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import ntpath
+
+import urllib.request
+import zipfile
+
+import pytest
+
+script_dir = os.path.dirname(__file__)
+@pytest.fixture(scope="session")
+def test_data_folder(request):
+ """
+ This fixture returns path to folder with shared test resources among all tests
+ """
+
+ data_dir = os.path.join(script_dir, "testdata")
+ if not os.path.exists(data_dir):
+ os.mkdir(data_dir)
+
+ files_to_download = ["https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/messi5.jpg",
+ "https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/basketball1.png",
+ "https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/Megamind.avi",
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"
+ ]
+
+ for file in files_to_download:
+ path, filename = ntpath.split(file)
+ file_path = os.path.join(data_dir, filename)
+ if not os.path.exists(file_path):
+ print("\nDownloading test file: " + file_path + "\n")
+ urllib.request.urlretrieve(file, file_path)
+
+ # Any unzipping needed, and moving around of files
+ with zipfile.ZipFile(os.path.join(data_dir, "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"), 'r') as zip_ref:
+ zip_ref.extractall(data_dir)
+
+ return data_dir
diff --git a/python/pyarmnn/examples/common/tests/context.py b/python/pyarmnn/examples/common/tests/context.py
new file mode 100644
index 0000000000..72246c03bf
--- /dev/null
+++ b/python/pyarmnn/examples/common/tests/context.py
@@ -0,0 +1,7 @@
+import os
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
+import cv_utils
+import network_executor
+import utils
diff --git a/python/pyarmnn/examples/common/tests/test_network_executor.py b/python/pyarmnn/examples/common/tests/test_network_executor.py
new file mode 100644
index 0000000000..e27b382078
--- /dev/null
+++ b/python/pyarmnn/examples/common/tests/test_network_executor.py
@@ -0,0 +1,24 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+
+import cv2
+
+from context import network_executor
+from context import cv_utils
+
+
+def test_execute_network(test_data_folder):
+ model_path = os.path.join(test_data_folder, "detect.tflite")
+ backends = ["CpuAcc", "CpuRef"]
+
+ executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
+ img = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
+ input_tensors = cv_utils.preprocess(img, executor.input_binding_info)
+
+ output_result = executor.run(input_tensors)
+
+ # Ensure it detects a person
+ classes = output_result[1]
+ assert classes[0][0] == 0
diff --git a/python/pyarmnn/examples/common/tests/test_utils.py b/python/pyarmnn/examples/common/tests/test_utils.py
new file mode 100644
index 0000000000..28d68ea235
--- /dev/null
+++ b/python/pyarmnn/examples/common/tests/test_utils.py
@@ -0,0 +1,19 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+
+from context import cv_utils
+from context import utils
+
+
+def test_get_source_encoding(test_data_folder):
+ video_file = os.path.join(test_data_folder, "Megamind.avi")
+ video, video_writer, frame_count = cv_utils.init_video_file_capture(video_file, "/tmp")
+ assert cv_utils.get_source_encoding_int(video) == 1145656920
+
+
+def test_read_existing_labels_file(test_data_folder):
+ label_file = os.path.join(test_data_folder, "labelmap.txt")
+ labels_map = utils.dict_labels(label_file)
+ assert labels_map is not None
diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py
new file mode 100644
index 0000000000..cf09fdefb8
--- /dev/null
+++ b/python/pyarmnn/examples/common/utils.py
@@ -0,0 +1,41 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Contains helper functions that can be used across the example apps."""
+
+import os
+import errno
+from pathlib import Path
+
+import numpy as np
+
+
+def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
+ """Creates a dictionary of labels from the input labels file.
+
+ Args:
+        labels_file_path: Path to file containing labels to map model outputs.
+ include_rgb: Adds randomly generated RGB values to the values of the
+ dictionary. Used for plotting bounding boxes of different colours.
+
+ Returns:
+ Dictionary with classification indices for keys and labels for values.
+
+ Raises:
+ FileNotFoundError:
+ Provided `labels_file_path` does not exist.
+ """
+ labels_file = Path(labels_file_path)
+ if not labels_file.is_file():
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), labels_file_path
+ )
+
+ labels = {}
+ with open(labels_file, "r") as f:
+ for idx, line in enumerate(f, 0):
+ if include_rgb:
+ labels[idx] = line.strip("\n"), tuple(np.random.random(size=3) * 255)
+ else:
+ labels[idx] = line.strip("\n")
+ return labels
diff --git a/python/pyarmnn/examples/image_classification/example_utils.py b/python/pyarmnn/examples/image_classification/example_utils.py
index 090ce2f5b3..f0ba91e981 100644
--- a/python/pyarmnn/examples/image_classification/example_utils.py
+++ b/python/pyarmnn/examples/image_classification/example_utils.py
@@ -38,7 +38,8 @@ def run_inference(runtime, net_id, images, labels, input_binding_info, output_bi
runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
# Process output
- out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+ # output tensor has a shape (1, 1001)
+ out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
results = np.argsort(out_tensor)[::-1]
print_top_n(5, results, labels, out_tensor)
@@ -121,7 +122,7 @@ def __create_network(model_file: str, backends: list, parser=None):
return net_id, parser, runtime
-def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
"""Creates a network from a tflite model file.
Args:
@@ -140,7 +141,7 @@ def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']
return net_id, graph_id, parser, runtime
-def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
"""Creates a network from an onnx model file.
Args:
@@ -181,7 +182,7 @@ def preprocess_default(img: Image, width: int, height: int, data_type, scale: fl
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
- scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
+ scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.),
preprocess_fn=preprocess_default):
"""Loads images, resizes and performs any additional preprocessing to run inference.
@@ -218,7 +219,6 @@ def load_labels(label_file: str):
with open(label_file, 'r') as f:
labels = [l.rstrip() for l in f]
return labels
- return None
def print_top_n(N: int, results: list, labels: list, prob: list):
@@ -299,10 +299,10 @@ def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str =
download_url = [download_url]
for dl in download_url:
archive = download_file(dl)
- if dl.lower().endswith(".zip"):
- unzip_file(archive)
+ if dl.lower().endswith(".zip"):
+ unzip_file(archive)
except RuntimeError:
- print("Unable to download file ({}).".format(archive_url))
+ print("Unable to download file ({}).".format(download_url))
if not os.path.exists(labels) or not os.path.exists(model):
raise RuntimeError("Unable to provide model and labels.")
@@ -310,7 +310,7 @@ def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str =
return model, labels
-def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')):
"""Lists files of a certain format in a folder.
Args:
@@ -338,7 +338,7 @@ def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
"""Gets image.
Args:
- image (str): Image filename
+ image_dir (str): Image filename
image_url (str): Image url
Returns:
diff --git a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
index 9e95f76dcc..be28b585ba 100644
--- a/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
+++ b/python/pyarmnn/examples/image_classification/onnx_mobilenetv2.py
@@ -44,48 +44,49 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float
return img
-args = eu.parse_command_line()
-
-model_filename = 'mobilenetv2-1.0.onnx'
-labels_filename = 'synset.txt'
-archive_filename = 'mobilenetv2-1.0.zip'
-labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
-model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
-
-# Download resources
-image_filenames = eu.get_images(args.data_dir)
-
-model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
- archive_filename,
- [model_url, labels_url])
-
-# all 3 resources must exist to proceed further
-assert os.path.exists(labels_filename)
-assert os.path.exists(model_filename)
-assert image_filenames
-for im in image_filenames:
- assert (os.path.exists(im))
-
-# Create a network from a model file
-net_id, parser, runtime = eu.create_onnx_network(model_filename)
-
-# Load input information from the model and create input tensors
-input_binding_info = parser.GetNetworkInputBindingInfo("data")
-
-# Load output information from the model and create output tensors
-output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
-output_tensors = ann.make_output_tensors([output_binding_info])
-
-# Load labels
-labels = eu.load_labels(labels_filename)
-
-# Load images and resize to expected size
-images = eu.load_images(image_filenames,
- 224, 224,
- np.float32,
- 255.0,
- [0.485, 0.456, 0.406],
- [0.229, 0.224, 0.225],
- preprocess_onnx)
-
-eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
+if __name__ == "__main__":
+ args = eu.parse_command_line()
+
+ model_filename = 'mobilenetv2-1.0.onnx'
+ labels_filename = 'synset.txt'
+ archive_filename = 'mobilenetv2-1.0.zip'
+ labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
+ model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
+
+ # Download resources
+ image_filenames = eu.get_images(args.data_dir)
+
+ model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename,
+ [model_url, labels_url])
+
+ # all 3 resources must exist to proceed further
+ assert os.path.exists(labels_filename)
+ assert os.path.exists(model_filename)
+ assert image_filenames
+ for im in image_filenames:
+ assert (os.path.exists(im))
+
+ # Create a network from a model file
+ net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+ # Load input information from the model and create input tensors
+ input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+ # Load output information from the model and create output tensors
+ output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+ output_tensors = ann.make_output_tensors([output_binding_info])
+
+ # Load labels
+ labels = eu.load_labels(labels_filename)
+
+ # Load images and resize to expected size
+ images = eu.load_images(image_filenames,
+ 224, 224,
+ np.float32,
+ 255.0,
+ [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225],
+ preprocess_onnx)
+
+ eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
diff --git a/python/pyarmnn/examples/image_classification/requirements.txt b/python/pyarmnn/examples/image_classification/requirements.txt
index f97e85636e..289a2b521a 100644
--- a/python/pyarmnn/examples/image_classification/requirements.txt
+++ b/python/pyarmnn/examples/image_classification/requirements.txt
@@ -1,4 +1,4 @@
requests>=2.23.0
urllib3>=1.25.8
Pillow>=6.1.0
-numpy>=1.18.1
+numpy>=1.19.2
diff --git a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
index 229a9b6778..6b35f63a00 100644
--- a/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
+++ b/python/pyarmnn/examples/image_classification/tflite_mobilenetv1_quantized.py
@@ -2,54 +2,53 @@
# Copyright © 2020 NXP and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
-import numpy as np
-import pyarmnn as ann
import example_utils as eu
import os
-args = eu.parse_command_line()
+if __name__ == "__main__":
+ args = eu.parse_command_line()
-# names of the files in the archive
-labels_filename = 'labels_mobilenet_quant_v1_224.txt'
-model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
-archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
+ # names of the files in the archive
+ labels_filename = 'labels_mobilenet_quant_v1_224.txt'
+ model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
+ archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
-archive_url = \
- 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
+ archive_url = \
+ 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
-model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
- archive_filename, archive_url)
+ model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+ archive_filename, archive_url)
-image_filenames = eu.get_images(args.data_dir)
+ image_filenames = eu.get_images(args.data_dir)
-# all 3 resources must exist to proceed further
-assert os.path.exists(labels_filename)
-assert os.path.exists(model_filename)
-assert image_filenames
-for im in image_filenames:
- assert(os.path.exists(im))
+ # all 3 resources must exist to proceed further
+ assert os.path.exists(labels_filename)
+ assert os.path.exists(model_filename)
+ assert image_filenames
+ for im in image_filenames:
+ assert(os.path.exists(im))
-# Create a network from the model file
-net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
+ # Create a network from the model file
+ net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
-# Load input information from the model
-# tflite has all the need information in the model unlike other formats
-input_names = parser.GetSubgraphInputTensorNames(graph_id)
-assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
+ # Load input information from the model
+    # tflite has all the needed information in the model, unlike other formats
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ assert len(input_names) == 1 # there should be 1 input tensor in mobilenet
-input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
-input_width = input_binding_info[1].GetShape()[1]
-input_height = input_binding_info[1].GetShape()[2]
+ input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+ input_width = input_binding_info[1].GetShape()[1]
+ input_height = input_binding_info[1].GetShape()[2]
-# Load output information from the model and create output tensors
-output_names = parser.GetSubgraphOutputTensorNames(graph_id)
-assert len(output_names) == 1 # and only one output tensor
-output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
+ # Load output information from the model and create output tensors
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+ assert len(output_names) == 1 # and only one output tensor
+ output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
-# Load labels file
-labels = eu.load_labels(labels_filename)
+ # Load labels file
+ labels = eu.load_labels(labels_filename)
-# Load images and resize to expected size
-images = eu.load_images(image_filenames, input_width, input_height)
+ # Load images and resize to expected size
+ images = eu.load_images(image_filenames, input_width, input_height)
-eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
+ eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
diff --git a/python/pyarmnn/examples/object_detection/requirements.txt b/python/pyarmnn/examples/object_detection/requirements.txt
index 7cc6379eb9..717a536a0e 100644
--- a/python/pyarmnn/examples/object_detection/requirements.txt
+++ b/python/pyarmnn/examples/object_detection/requirements.txt
@@ -1,3 +1,2 @@
-argparse>=1.4.0
-numpy>=1.19.0
-tqdm>=4.47.0 \ No newline at end of file
+numpy>=1.19.2
+tqdm>=4.47.0
diff --git a/python/pyarmnn/examples/object_detection/run_video_file.py b/python/pyarmnn/examples/object_detection/run_video_file.py
index 4f06eb184d..fc3e214721 100644
--- a/python/pyarmnn/examples/object_detection/run_video_file.py
+++ b/python/pyarmnn/examples/object_detection/run_video_file.py
@@ -7,55 +7,19 @@ bounding boxes and labels around detected objects, and saves the processed video
"""
import os
+import sys
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
import cv2
-import pyarmnn as ann
from tqdm import tqdm
from argparse import ArgumentParser
from ssd import ssd_processing, ssd_resize_factor
from yolo import yolo_processing, yolo_resize_factor
-from utils import create_video_writer, create_network, dict_labels, preprocess, execute_network, draw_bounding_boxes
-
-
-parser = ArgumentParser()
-parser.add_argument('--video_file_path', required=True, type=str,
- help='Path to the video file to run object detection on')
-parser.add_argument('--model_file_path', required=True, type=str,
- help='Path to the Object Detection model to use')
-parser.add_argument('--model_name', required=True, type=str,
- help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
-parser.add_argument('--label_path', type=str,
- help='Path to the labelset for the provided model file')
-parser.add_argument('--output_video_file_path', type=str,
- help='Path to the output video file with detections added in')
-parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
- help='Takes the preferred backends in preference order, separated by whitespace, '
- 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
- 'Defaults to [CpuAcc, CpuRef]')
-args = parser.parse_args()
-
-
-def init_video(video_path: str, output_path: str):
- """
- Creates a video capture object from a video file.
-
- Args:
- video_path: User-specified video file path.
- output_path: Optional path to save the processed video.
-
- Returns:
- Video capture object to capture frames, video writer object to write processed
- frames to file, plus total frame count of video source to iterate through.
- """
- if not os.path.exists(video_path):
- raise FileNotFoundError(f'Video file not found for: {video_path}')
- video = cv2.VideoCapture(video_path)
- if not video.isOpened:
- raise RuntimeError(f'Failed to open video capture from file: {video_path}')
-
- video_writer = create_video_writer(video, video_path, output_path)
- iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
- return video, video_writer, iter_frame_count
+from utils import dict_labels
+from cv_utils import init_video_file_capture, preprocess, draw_bounding_boxes
+from network_executor import ArmnnNetworkExecutor
def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
@@ -72,30 +36,29 @@ def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding
Model labels, decoding and processing functions.
"""
if model_name == 'ssd_mobilenet_v1':
- labels = os.path.join('ssd_labels.txt')
+ labels = os.path.join(script_dir, 'ssd_labels.txt')
return labels, ssd_processing, ssd_resize_factor(video)
elif model_name == 'yolo_v3_tiny':
- labels = os.path.join('yolo_labels.txt')
+ labels = os.path.join(script_dir, 'yolo_labels.txt')
return labels, yolo_processing, yolo_resize_factor(video, input_binding_info)
else:
raise ValueError(f'{model_name} is not a valid model name')
def main(args):
- video, video_writer, frame_count = init_video(args.video_file_path, args.output_video_file_path)
- net_id, runtime, input_binding_info, output_binding_info = create_network(args.model_file_path,
- args.preferred_backends)
- output_tensors = ann.make_output_tensors(output_binding_info)
- labels, process_output, resize_factor = get_model_processing(args.model_name, video, input_binding_info)
- labels = dict_labels(labels if args.label_path is None else args.label_path)
+ video, video_writer, frame_count = init_video_file_capture(args.video_file_path, args.output_video_file_path)
+
+ executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+ labels, process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
+ labels = dict_labels(labels if args.label_path is None else args.label_path, include_rgb=True)
for _ in tqdm(frame_count, desc='Processing frames'):
frame_present, frame = video.read()
if not frame_present:
continue
- input_tensors = preprocess(frame, input_binding_info)
- inference_output = execute_network(input_tensors, output_tensors, runtime, net_id)
- detections = process_output(inference_output)
+ input_tensors = preprocess(frame, executor.input_binding_info)
+ output_result = executor.run(input_tensors)
+ detections = process_output(output_result)
draw_bounding_boxes(frame, detections, resize_factor, labels)
video_writer.write(frame)
print('Finished processing frames')
@@ -103,4 +66,20 @@ def main(args):
if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--video_file_path', required=True, type=str,
+ help='Path to the video file to run object detection on')
+ parser.add_argument('--model_file_path', required=True, type=str,
+ help='Path to the Object Detection model to use')
+ parser.add_argument('--model_name', required=True, type=str,
+ help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
+ parser.add_argument('--label_path', type=str,
+ help='Path to the labelset for the provided model file')
+ parser.add_argument('--output_video_file_path', type=str,
+ help='Path to the output video file with detections added in')
+ parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
+ help='Takes the preferred backends in preference order, separated by whitespace, '
+ 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
+ 'Defaults to [CpuAcc, CpuRef]')
+ args = parser.parse_args()
main(args)
diff --git a/python/pyarmnn/examples/object_detection/run_video_stream.py b/python/pyarmnn/examples/object_detection/run_video_stream.py
index 94dc6c8b13..9a303e8129 100644
--- a/python/pyarmnn/examples/object_detection/run_video_stream.py
+++ b/python/pyarmnn/examples/object_detection/run_video_stream.py
@@ -8,47 +8,18 @@ and displays a window with the latest processed frame.
"""
import os
+import sys
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
import cv2
-import pyarmnn as ann
-from tqdm import tqdm
from argparse import ArgumentParser
from ssd import ssd_processing, ssd_resize_factor
from yolo import yolo_processing, yolo_resize_factor
-from utils import create_network, dict_labels, preprocess, execute_network, draw_bounding_boxes
-
-
-parser = ArgumentParser()
-parser.add_argument('--video_source', type=int, default=0,
- help='Device index to access video stream. Defaults to primary device camera at index 0')
-parser.add_argument('--model_file_path', required=True, type=str,
- help='Path to the Object Detection model to use')
-parser.add_argument('--model_name', required=True, type=str,
- help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
-parser.add_argument('--label_path', type=str,
- help='Path to the labelset for the provided model file')
-parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
- help='Takes the preferred backends in preference order, separated by whitespace, '
- 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
- 'Defaults to [CpuAcc, CpuRef]')
-args = parser.parse_args()
-
-
-def init_video(video_source: int):
- """
- Creates a video capture object from a device.
-
- Args:
- video_source: Device index used to read video stream.
-
- Returns:
- Video capture object used to capture frames from a video stream.
- """
- video = cv2.VideoCapture(video_source)
- if not video.isOpened:
- raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
- print('Processing video stream. Press \'Esc\' key to exit the demo.')
- return video
+from utils import dict_labels
+from cv_utils import init_video_stream_capture, preprocess, draw_bounding_boxes
+from network_executor import ArmnnNetworkExecutor
def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
@@ -65,31 +36,31 @@ def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding
Model labels, decoding and processing functions.
"""
if model_name == 'ssd_mobilenet_v1':
- labels = os.path.join('ssd_labels.txt')
+ labels = os.path.join(script_dir, 'ssd_labels.txt')
return labels, ssd_processing, ssd_resize_factor(video)
elif model_name == 'yolo_v3_tiny':
- labels = os.path.join('yolo_labels.txt')
+ labels = os.path.join(script_dir, 'yolo_labels.txt')
return labels, yolo_processing, yolo_resize_factor(video, input_binding_info)
else:
raise ValueError(f'{model_name} is not a valid model name')
def main(args):
- video = init_video(args.video_source)
- net_id, runtime, input_binding_info, output_binding_info = create_network(args.model_file_path,
- args.preferred_backends)
- output_tensors = ann.make_output_tensors(output_binding_info)
- labels, process_output, resize_factor = get_model_processing(args.model_name, video, input_binding_info)
- labels = dict_labels(labels if args.label_path is None else args.label_path)
+ video = init_video_stream_capture(args.video_source)
+ executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+
+ labels, process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
+ labels = dict_labels(labels if args.label_path is None else args.label_path, include_rgb=True)
while True:
frame_present, frame = video.read()
frame = cv2.flip(frame, 1) # Horizontally flip the frame
if not frame_present:
raise RuntimeError('Error reading frame from video stream')
- input_tensors = preprocess(frame, input_binding_info)
- inference_output = execute_network(input_tensors, output_tensors, runtime, net_id)
- detections = process_output(inference_output)
+ input_tensors = preprocess(frame, executor.input_binding_info)
+ print("Running inference...")
+ output_result = executor.run(input_tensors)
+ detections = process_output(output_result)
draw_bounding_boxes(frame, detections, resize_factor, labels)
cv2.imshow('PyArmNN Object Detection Demo', frame)
if cv2.waitKey(1) == 27:
@@ -99,4 +70,18 @@ def main(args):
if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--video_source', type=int, default=0,
+ help='Device index to access video stream. Defaults to primary device camera at index 0')
+ parser.add_argument('--model_file_path', required=True, type=str,
+ help='Path to the Object Detection model to use')
+ parser.add_argument('--model_name', required=True, type=str,
+ help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
+ parser.add_argument('--label_path', type=str,
+ help='Path to the labelset for the provided model file')
+ parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
+ help='Takes the preferred backends in preference order, separated by whitespace, '
+ 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
+ 'Defaults to [CpuAcc, CpuRef]')
+ args = parser.parse_args()
main(args)
diff --git a/python/pyarmnn/examples/speech_recognition/README.md b/python/pyarmnn/examples/speech_recognition/README.md
new file mode 100644
index 0000000000..10a583f123
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/README.md
@@ -0,0 +1,158 @@
+# Automatic Speech Recognition with PyArmNN
+
+This sample application guides the user through performing automatic speech recognition (ASR) with the PyArmNN API.
+
+## Prerequisites
+
+### PyArmNN
+
+Before proceeding to the next steps, make sure that you have successfully installed the newest version of PyArmNN on your system by following the instructions in the README of the PyArmNN root directory.
+
+You can verify that the PyArmNN library is installed and check its version using:
+
+```bash
+$ pip show pyarmnn
+```
+
+You can also verify it by running the following command and checking that the output is similar to the example below:
+
+```bash
+$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
+'22.0.0'
+```
+
+### Dependencies
+
+Install the PortAudio and libsndfile system packages:
+
+```bash
+$ sudo apt-get install libsndfile1 libportaudio2
+```
+
+Install the required Python modules:
+
+```bash
+$ pip install -r requirements.txt
+```
+
+## Performing Automatic Speech Recognition
+
+### Processing Audio Files
+
+To run ASR on an audio file, use the following command:
+
+```bash
+$ python run_audio_file.py --audio_file_path <path/to/your_audio> --model_file_path <path/to/your_model> --labels_file_path <path/to/your_labels>
+```
+
+You may also add the optional flags:
+
+* `--preferred_backends`
+
+  * Takes the preferred backends in preference order, separated by whitespace. For example, passing in "CpuAcc CpuRef" will be read as the list ["CpuAcc", "CpuRef"], which is also the default. See the example command after this list.
+
+ * CpuAcc represents the CPU backend
+
+ * GpuAcc represents the GPU backend
+
+ * CpuRef represents the CPU reference kernels
+
+* `--help` prints all available options to screen
+
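+For example, to prefer the GPU backend and fall back to the CPU kernels, you could run something like the following (the paths are placeholders for files you supply yourself):
+
+```bash
+$ python run_audio_file.py --audio_file_path <path/to/your_audio> --model_file_path <path/to/your_model> --labels_file_path <path/to/your_labels> --preferred_backends GpuAcc CpuAcc CpuRef
+```
+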
+## Application Overview
+
+1. [Initialization](#initialization)
+
+2. [Creating a network](#creating-a-network)
+
+3. [Automatic speech recognition pipeline](#automatic-speech-recognition-pipeline)
+
+### Initialization
+
+The application parses the supplied user arguments and loads the audio file into the `AudioCapture` class, which initialises the audio source and sets the sampling parameters required by the model using the `ModelParams` class.
+
+`AudioCapture` helps us capture chunks of audio data from the source. When performing ASR on an audio file, the application creates a generator object that yields blocks of audio data from the file with a minimum sample size.
+
+To interpret the inference result of the loaded network, the application must load the labels that are associated with the model. The `dict_labels()` function creates a dictionary that is keyed on the classification index at the output node of the model. The values of the dictionary are the corresponding characters.
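+
+As a minimal sketch of this step, loading the labels reduces to a single call to the shared helper added in `common/utils.py` (here the argument name mirrors the `--labels_file_path` flag described above):
+
+```python
+from utils import dict_labels  # common/utils.py, added in this patch
+
+# Map each classification index at the model output to its character label
+labels = dict_labels(args.labels_file_path)
+```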
+
+### Creating a network
+
+A PyArmNN application must import a graph from file using an appropriate parser. Arm NN provides parsers for various model file types, including TFLite, TF, and ONNX. These parsers are libraries for loading neural networks of various formats into the Arm NN runtime.
+
+Arm NN supports optimized execution on multiple CPU, GPU, and Ethos-N devices. Before executing a graph, the application must select the appropriate device context by using `IRuntime()` to create a runtime context with default options. We can optimize the imported graph by specifying a list of backends in order of preference; each backend applies its own backend-specific optimizations and is identified by a unique string, for example CpuAcc, GpuAcc, and CpuRef represent the accelerated CPU backend, the accelerated GPU backend, and the CPU reference kernels, respectively.
+
+Arm NN splits the entire graph into subgraphs based on these backends. Each subgraph is then optimized, and the corresponding subgraph in the original graph is substituted with its optimized version.
+
+The `Optimize()` function optimizes the graph for inference, then `LoadNetwork()` loads the optimized network onto the compute device. The `LoadNetwork()` function also creates the backend-specific workloads for the layers and a backend-specific workload factory.
+
+Parsers extract the input information for the network. The `GetSubgraphInputTensorNames()` function extracts all the input names and the `GetNetworkInputBindingInfo()` function obtains the input binding information of the graph. The input binding information contains all the essential information about the input. This information is a tuple consisting of integer identifiers for bindable layers and tensor information (data type, quantization info, dimension count, total elements).
+
+Similarly, we can get the output binding information for an output layer by using the parser to retrieve output tensor names and calling the `GetNetworkOutputBindingInfo()` function.
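+
+For instance, the tensor half of a binding tuple exposes the shape and data type of the bound layer, which other examples in this patch use during preprocessing:
+
+```python
+# input_binding_info is a (layer_id, TensorInfo) tuple
+input_shape = input_binding_info[1].GetShape()
+input_data_type = input_binding_info[1].GetDataType()
+```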
+
+For this application, the main point of contact with PyArmNN is through the `ArmnnNetworkExecutor` class, which will handle the network creation step for you.
+
+```python
+# common/network_executor.py
+# The provided wav2letter model is in .tflite format so we use TfLiteParser() to import the graph
+if ext == '.tflite':
+ parser = ann.ITfLiteParser()
+network = parser.CreateNetworkFromBinaryFile(model_file)
+...
+# Optimize the network for the list of preferred backends
+opt_network, messages = ann.Optimize(
+ network, preferred_backends, self.runtime.GetDeviceSpec(), ann.OptimizerOptions()
+ )
+# Load the optimized network onto the runtime device
+self.network_id, _ = self.runtime.LoadNetwork(opt_network)
+# Get the input and output binding information
+self.input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+self.output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
+```
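+
+In practice, the examples in this patch construct and run the executor roughly as follows. This is a sketch using this application's command-line argument names; the exact wiring in `run_audio_file.py` may differ:
+
+```python
+import pyarmnn as ann
+from network_executor import ArmnnNetworkExecutor
+
+# Create and optimize the network for the preferred backends
+executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+
+# Wrap a preprocessed (and quantized) feature array as the input workload tensor
+input_tensors = ann.make_input_tensors([executor.input_binding_info], [input_data])
+output_result = executor.run(input_tensors)
+```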
+
+### Automatic speech recognition pipeline
+
+The `MFCC` class is used to extract the Mel-frequency Cepstral Coefficients (MFCCs, [see Wikipedia](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) from a given audio frame to be used as features for the network. MFCCs are the result of computing the dot product of the Discrete Cosine Transform (DCT) matrix and the log of the Mel energy.
+
+After all the MFCCs needed for an inference have been extracted from the audio data, we convolve them with 1-dimensional Savitzky-Golay filters to compute the first and second MFCC derivatives with respect to time. The MFCCs and the derivatives are concatenated to make the input tensor for the model.
+
+```python
+# preprocess.py
+# Extract MFCC features
+log_mel_energy = np.maximum(log_mel_energy, log_mel_energy.max() - top_db)
+mfcc_feats = np.dot(self.__dct_matrix, log_mel_energy)
+...
+# Compute first and second derivatives (delta and delta-delta respectively) by passing a
+# Savitzky-Golay filter as a 1D convolution over the features
+for i in range(features.shape[1]):
+ idelta = np.convolve(features[:, i], self.__savgol_order1_coeffs, 'same')
+ mfcc_delta_np[:, i] = (idelta)
+ ideltadelta = np.convolve(features[:, i], self.savgol_order2_coeffs, 'same')
+ mfcc_delta2_np[:, i] = (ideltadelta)
+```
+
+```python
+# audio_utils.py
+# Quantize the input data and create input tensors with PyArmNN
+input_tensor = quantize_input(input_tensor, input_binding_info)
+input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+```
+
+Note: `ArmnnNetworkExecutor` has already created the output tensors for you.
+
+After creating the workload tensors, the compute device performs inference for the loaded network by using the `EnqueueWorkload()` function of the runtime context. Calling the `workload_tensors_to_ndarray()` function obtains the inference results as a list of ndarrays.
+
+```python
+# common/network_executor.py
+status = runtime.EnqueueWorkload(net_id, input_tensors, self.output_tensors)
+self.output_result = ann.workload_tensors_to_ndarray(self.output_tensors)
+```
+
+The output from the inference must be decoded to obtain the recognised characters from the speech. A simple greedy decoder classifies the results by taking the highest element of the output as a key for the labels dictionary. The value returned is a character which is appended to a list, and the list is filtered to remove unwanted characters. The produced string is displayed on the console.
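+
+The greedy decoder is implemented in `audio_utils.py`; `filter_characters()` then removes the blank symbol and collapses repeated characters:
+
+```python
+# audio_utils.py
+def decode(model_output: np.ndarray, labels: dict) -> str:
+    """Decodes the integer encoded results from inference into a string."""
+    top1_results = [labels[np.argmax(row[0])] for row in model_output]
+    return filter_characters(top1_results)
+```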
+
+## Next steps
+
+Having now gained a solid understanding of performing automatic speech recognition with PyArmNN, you are able to take control and create your own application. As a next step, we suggest trying your own network, which you can do by updating the parameters of `ModelParams` and `MFCCParams` to match your custom model. The `ArmnnNetworkExecutor` class will handle the network optimisation and loading for you.
+
+An important way to improve the accuracy of the generated output sentences is to provide cleaner data to the network. You can do this by adding preprocessing steps such as noise reduction of your audio data.
+
+In this application, we used a greedy decoder to decode the integer-encoded output; however, better results can be achieved by implementing a beam search decoder. You may even try adding a language model at the end to correct any spelling mistakes the model may produce.
diff --git a/python/pyarmnn/examples/speech_recognition/__init__.py b/python/pyarmnn/examples/speech_recognition/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/__init__.py
diff --git a/python/pyarmnn/examples/speech_recognition/audio_capture.py b/python/pyarmnn/examples/speech_recognition/audio_capture.py
new file mode 100644
index 0000000000..9f28d1006e
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/audio_capture.py
@@ -0,0 +1,56 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Contains AudioCapture class for capturing chunks of audio data from file."""
+
+from typing import Generator
+
+import numpy as np
+import soundfile as sf
+
+
+class ModelParams:
+ def __init__(self, model_file_path: str):
+ """Defines sampling parameters for model used.
+
+ Args:
+ model_file_path: Path to ASR model to use.
+ """
+ self.path = model_file_path
+ self.mono = True
+ self.dtype = np.float32
+ self.samplerate = 16000
+ self.min_samples = 167392
+
+
+class AudioCapture:
+ def __init__(self, model_params):
+ """Sampling parameters for model used."""
+ self.model_params = model_params
+
+ def from_audio_file(self, audio_file_path, overlap=31712) -> Generator[np.ndarray, None, None]:
+ """Creates a generator that yields audio data from a file. Data is padded with
+ zeros if necessary to make up minimum number of samples.
+
+ Args:
+ audio_file_path: Path to audio file provided by user.
+ overlap: The overlap with previous buffer. We need the offset to be the same as the inner context
+ of the mfcc output, which is sized as 100 x 39. Each mfcc compute produces 1 x 39 vector,
+ and consumes 160 audio samples. The default overlap is then calculated to be 47712 - (160 x 100)
+ where 47712 is the min_samples needed for 1 inference of wav2letter.
+
+ Yields:
+ Blocks of audio data of minimum sample size.
+ """
+ with sf.SoundFile(audio_file_path) as audio_file:
+ for block in audio_file.blocks(
+ blocksize=self.model_params.min_samples,
+ dtype=self.model_params.dtype,
+ always_2d=True,
+ fill_value=0,
+ overlap=overlap
+ ):
+            # Average across the channel axis to convert to mono if specified
+            if self.model_params.mono and block.ndim > 1:
+                block = np.mean(block, axis=1)
+ yield block
diff --git a/python/pyarmnn/examples/speech_recognition/audio_utils.py b/python/pyarmnn/examples/speech_recognition/audio_utils.py
new file mode 100644
index 0000000000..a522a0e2a7
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/audio_utils.py
@@ -0,0 +1,128 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Utilities for speech recognition apps."""
+
+import numpy as np
+import pyarmnn as ann
+
+
+def decode(model_output: np.ndarray, labels: dict) -> str:
+ """Decodes the integer encoded results from inference into a string.
+
+ Args:
+ model_output: Results from running inference.
+ labels: Dictionary of labels keyed on the classification index.
+
+ Returns:
+ Decoded string.
+ """
+ top1_results = [labels[np.argmax(row[0])] for row in model_output]
+ return filter_characters(top1_results)
+
+
+def filter_characters(results: list) -> str:
+ """Filters unwanted and duplicate characters.
+
+ Args:
+ results: List of top 1 results from inference.
+
+ Returns:
+ Final output string to present to user.
+ """
+ text = ""
+ for i in range(len(results)):
+ if results[i] == "$":
+ continue
+ elif i + 1 < len(results) and results[i] == results[i + 1]:
+ continue
+ else:
+ text += results[i]
+ return text
+
+
+def display_text(text: str):
+ """Presents the results on the console.
+
+ Args:
+ text: Results of performing ASR on the input audio data.
+ """
+ print(text, sep="", end="", flush=True)
+
+
+def quantize_input(data, input_binding_info):
+ """Quantize the float input to (u)int8 ready for inputting to model."""
+ if data.ndim != 2:
+ raise RuntimeError("Audio data must have 2 dimensions for quantization")
+
+ quant_scale = input_binding_info[1].GetQuantizationScale()
+ quant_offset = input_binding_info[1].GetQuantizationOffset()
+ data_type = input_binding_info[1].GetDataType()
+
+ if data_type == ann.DataType_QAsymmS8:
+ data_type = np.int8
+ elif data_type == ann.DataType_QAsymmU8:
+ data_type = np.uint8
+ else:
+ raise ValueError("Could not quantize data to required data type")
+
+ d_min = np.iinfo(data_type).min
+ d_max = np.iinfo(data_type).max
+
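+    # Apply the affine quantization (value / scale + offset) element-wise and clip to the target range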
+ for row in range(data.shape[0]):
+ for col in range(data.shape[1]):
+ data[row, col] = (data[row, col] / quant_scale) + quant_offset
+ data[row, col] = np.clip(data[row, col], d_min, d_max)
+ data = data.astype(data_type)
+ return data
+
+
+def decode_text(is_first_window, labels, output_result):
+ """
+    Slices the model output appropriately depending on the window, and decodes it for wav2letter output.
+    * On the first run, take the left context and the inner context.
+    * On every other run, take only the inner context.
+    The current right context is stored and updated on each inference; it is used after the last inference.
+
+    Args:
+        is_first_window: Boolean indicating whether this is the first window we are running inference on.
+        labels: The label set.
+        output_result: The output from the inference.
+    Returns:
+        current_r_context: The current right context.
+        text: The text decoded from the current window.
+ """
+
+ if is_first_window:
+ # Since it's the first inference, keep the left context, and inner context, and decode
+ text = decode(output_result[0][0:472], labels)
+ else:
+ # Only decode the inner context
+ text = decode(output_result[0][49:472], labels)
+
+ # Store the right context, we will need it after the last inference
+ current_r_context = decode(output_result[0][473:521], labels)
+ return current_r_context, text
+
+
+def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
+ """
+    Takes a block of audio data, extracts the MFCC features, quantizes the array if the model requires it,
+    and uses ArmNN to create the input tensors.
+
+    Args:
+        audio_data: The audio data to process.
+        input_binding_info: The model input binding info.
+        mfcc_preprocessor: The MFCC preprocessor instance.
+    Returns:
+        input_tensors: The prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor.
+ """
+
+ data_type = input_binding_info[1].GetDataType()
+ input_tensor = mfcc_preprocessor.extract_features(audio_data)
+ if data_type != ann.DataType_Float32:
+ input_tensor = quantize_input(input_tensor, input_binding_info)
+ input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+ return input_tensors
diff --git a/python/pyarmnn/examples/speech_recognition/preprocess.py b/python/pyarmnn/examples/speech_recognition/preprocess.py
new file mode 100644
index 0000000000..553ddba5de
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/preprocess.py
@@ -0,0 +1,260 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Class used to extract the Mel-frequency cepstral coefficients from a given audio frame."""
+
+import numpy as np
+
+
+class MFCCParams:
+ def __init__(self, sampling_freq, num_fbank_bins,
+ mel_lo_freq, mel_hi_freq, num_mfcc_feats, frame_len, use_htk_method, n_FFT):
+ self.sampling_freq = sampling_freq
+ self.num_fbank_bins = num_fbank_bins
+ self.mel_lo_freq = mel_lo_freq
+ self.mel_hi_freq = mel_hi_freq
+ self.num_mfcc_feats = num_mfcc_feats
+ self.frame_len = frame_len
+ self.use_htk_method = use_htk_method
+ self.n_FFT = n_FFT
+
+
+class MFCC:
+
+ def __init__(self, mfcc_params):
+ self.mfcc_params = mfcc_params
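+        # Mel-scale constants (Slaney-style): linear spacing of 200/3 Hz per mel below 1 kHz,
+        # logarithmic spacing above it with a step of ln(6.4)/27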
+ self.FREQ_STEP = 200.0 / 3
+ self.MIN_LOG_HZ = 1000.0
+ self.MIN_LOG_MEL = self.MIN_LOG_HZ / self.FREQ_STEP
+ self.LOG_STEP = 1.8562979903656 / 27.0
+ self.__frame_len_padded = int(2 ** (np.ceil((np.log(self.mfcc_params.frame_len) / np.log(2.0)))))
+ self.__filter_bank_initialised = False
+ self.__frame = np.zeros(self.__frame_len_padded)
+ self.__buffer = np.zeros(self.__frame_len_padded)
+ self.__filter_bank_filter_first = np.zeros(self.mfcc_params.num_fbank_bins)
+ self.__filter_bank_filter_last = np.zeros(self.mfcc_params.num_fbank_bins)
+ self.__mel_energies = np.zeros(self.mfcc_params.num_fbank_bins)
+ self.__dct_matrix = self.create_dct_matrix(self.mfcc_params.num_fbank_bins, self.mfcc_params.num_mfcc_feats)
+ self.__mel_filter_bank = self.create_mel_filter_bank()
+ self.__np_mel_bank = np.zeros([self.mfcc_params.num_fbank_bins, int(self.mfcc_params.n_FFT / 2) + 1])
+
+ for i in range(self.mfcc_params.num_fbank_bins):
+ k = 0
+ for j in range(int(self.__filter_bank_filter_first[i]), int(self.__filter_bank_filter_last[i]) + 1):
+ self.__np_mel_bank[i, j] = self.__mel_filter_bank[i][k]
+ k += 1
+
+ def mel_scale(self, freq, use_htk_method):
+ """
+        Converts a frequency in Hz to the mel scale.
+
+        Args:
+            freq: The frequency in Hz.
+            use_htk_method: Boolean to set whether to use the HTK formula or not.
+
+        Returns:
+            The frequency converted to the mel scale.
+ """
+ if use_htk_method:
+ return 1127.0 * np.log(1.0 + freq / 700.0)
+ else:
+ mel = freq / self.FREQ_STEP
+
+ if freq >= self.MIN_LOG_HZ:
+ mel = self.MIN_LOG_MEL + np.log(freq / self.MIN_LOG_HZ) / self.LOG_STEP
+ return mel
+
+ def inv_mel_scale(self, mel_freq, use_htk_method):
+ """
+        Converts a mel-scale value back to a frequency in Hz.
+
+        Args:
+            mel_freq: The mel frequency.
+            use_htk_method: Boolean to set whether to use the HTK formula or not.
+
+        Returns:
+            The frequency in Hz.
+ """
+ if use_htk_method:
+ return 700.0 * (np.exp(mel_freq / 1127.0) - 1.0)
+ else:
+ freq = self.FREQ_STEP * mel_freq
+
+ if mel_freq >= self.MIN_LOG_MEL:
+ freq = self.MIN_LOG_HZ * np.exp(self.LOG_STEP * (mel_freq - self.MIN_LOG_MEL))
+ return freq
+
+ def mfcc_compute(self, audio_data):
+ """
+ Extracts the MFCC for a single frame.
+
+ Args:
+ audio_data: The audio data to process.
+
+ Returns:
+ the MFCC features
+ """
+ if len(audio_data) != self.mfcc_params.frame_len:
+ raise ValueError(
+ f"audio_data buffer size {len(audio_data)} does not match the frame length {self.mfcc_params.frame_len}")
+
+ audio_data = np.array(audio_data)
+ spec = np.abs(np.fft.rfft(np.hanning(self.mfcc_params.n_FFT + 1)[0:self.mfcc_params.n_FFT] * audio_data,
+ self.mfcc_params.n_FFT)) ** 2
+ mel_energy = np.dot(self.__np_mel_bank.astype(np.float32),
+ np.transpose(spec).astype(np.float32))
+
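+        # Convert the mel energies to a log (dB) scale; the small constant avoids log(0)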
+ mel_energy += 1e-10
+ log_mel_energy = 10.0 * np.log10(mel_energy)
+ top_db = 80.0
+
+ log_mel_energy = np.maximum(log_mel_energy, log_mel_energy.max() - top_db)
+
+ mfcc_feats = np.dot(self.__dct_matrix, log_mel_energy)
+
+ return mfcc_feats
+
+ def create_dct_matrix(self, num_fbank_bins, num_mfcc_feats):
+ """
+ Creates the Discrete Cosine Transform matrix to be used in the compute function.
+
+ Args:
+ num_fbank_bins: The number of filter bank bins
+ num_mfcc_feats: the number of MFCC features
+
+ Returns:
+ the DCT matrix
+ """
+ dct_m = np.zeros(num_fbank_bins * num_mfcc_feats)
+ for k in range(num_mfcc_feats):
+ for n in range(num_fbank_bins):
+ if k == 0:
+ dct_m[(k * num_fbank_bins) + n] = 2 * np.sqrt(1 / (4 * num_fbank_bins)) * np.cos(
+ (np.pi / num_fbank_bins) * (n + 0.5) * k)
+ else:
+ dct_m[(k * num_fbank_bins) + n] = 2 * np.sqrt(1 / (2 * num_fbank_bins)) * np.cos(
+ (np.pi / num_fbank_bins) * (n + 0.5) * k)
+
+ dct_m = np.reshape(dct_m, [self.mfcc_params.num_mfcc_feats, self.mfcc_params.num_fbank_bins])
+ return dct_m
+
+ def create_mel_filter_bank(self):
+ """
+ Creates the Mel filter bank.
+
+ Returns:
+ the mel filter bank
+ """
+ num_fft_bins = int(self.__frame_len_padded / 2)
+ fft_bin_width = self.mfcc_params.sampling_freq / self.__frame_len_padded
+
+ mel_low_freq = self.mel_scale(self.mfcc_params.mel_lo_freq, False)
+ mel_high_freq = self.mel_scale(self.mfcc_params.mel_hi_freq, False)
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (self.mfcc_params.num_fbank_bins + 1)
+
+ this_bin = np.zeros(num_fft_bins)
+ mel_fbank = [0] * self.mfcc_params.num_fbank_bins
+
+ for bin_num in range(self.mfcc_params.num_fbank_bins):
+ left_mel = mel_low_freq + bin_num * mel_freq_delta
+ center_mel = mel_low_freq + (bin_num + 1) * mel_freq_delta
+ right_mel = mel_low_freq + (bin_num + 2) * mel_freq_delta
+ first_index = last_index = -1
+
+ for i in range(num_fft_bins):
+ freq = (fft_bin_width * i)
+ mel = self.mel_scale(freq, False)
+ this_bin[i] = 0.0
+
+ if (mel > left_mel) and (mel < right_mel):
+ if mel <= center_mel:
+ weight = (mel - left_mel) / (center_mel - left_mel)
+ else:
+ weight = (right_mel - mel) / (right_mel - center_mel)
+
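+                    # Slaney-style normalization so that each triangular filter has approximately unit area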
+ enorm = 2.0 / (self.inv_mel_scale(right_mel, False) - self.inv_mel_scale(left_mel, False))
+ weight *= enorm
+ this_bin[i] = weight
+
+ if first_index == -1:
+ first_index = i
+ last_index = i
+
+ self.__filter_bank_filter_first[bin_num] = first_index
+ self.__filter_bank_filter_last[bin_num] = last_index
+ mel_fbank[bin_num] = np.zeros(last_index - first_index + 1)
+ j = 0
+
+ for i in range(first_index, last_index + 1):
+ mel_fbank[bin_num][j] = this_bin[i]
+ j += 1
+
+ return mel_fbank
+
+
+class Preprocessor:
+
+ def __init__(self, mfcc, model_input_size, stride):
+ self.model_input_size = model_input_size
+ self.stride = stride
+
+ # Savitzky - Golay differential filters
+ self.__savgol_order1_coeffs = np.array([6.66666667e-02, 5.00000000e-02, 3.33333333e-02,
+ 1.66666667e-02, -3.46944695e-18, -1.66666667e-02,
+ -3.33333333e-02, -5.00000000e-02, -6.66666667e-02])
+
+ self.savgol_order2_coeffs = np.array([0.06060606, 0.01515152, -0.01731602,
+ -0.03679654, -0.04329004, -0.03679654,
+ -0.01731602, 0.01515152, 0.06060606])
+
+ self.__mfcc_calc = mfcc
+
+ def __normalize(self, values):
+ """
+ Normalize values to mean 0 and std 1
+ """
+ ret_val = (values - np.mean(values)) / np.std(values)
+ return ret_val
+
+ def __get_features(self, features, mfcc_instance, audio_data):
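+        # Slide over the audio with the configured stride until enough MFCC frames
+        # have been computed to fill one model input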
+ idx = 0
+ while len(features) < self.model_input_size * mfcc_instance.mfcc_params.num_mfcc_feats:
+ features.extend(mfcc_instance.mfcc_compute(audio_data[idx:idx + int(mfcc_instance.mfcc_params.frame_len)]))
+ idx += self.stride
+
+ def extract_features(self, audio_data):
+ """
+        Extracts the MFCC features and calculates each feature's first and second order derivatives.
+        The matrix returned is sized appropriately for input to the model, based
+        on the model info specified in the MFCC instance.
+
+        Args:
+            audio_data: The audio data to be used for this calculation.
+        Returns:
+            The derived MFCC feature matrix, sized appropriately for inference.
+ """
+
+ num_samples_per_inference = ((self.model_input_size - 1)
+ * self.stride) + self.__mfcc_calc.mfcc_params.frame_len
+ if len(audio_data) < num_samples_per_inference:
+ raise ValueError("audio_data size for feature extraction is smaller than "
+ "the expected number of samples needed for inference")
+
+ features = []
+ self.__get_features(features, self.__mfcc_calc, np.asarray(audio_data))
+ features = np.reshape(np.array(features), (self.model_input_size, self.__mfcc_calc.mfcc_params.num_mfcc_feats))
+
+ mfcc_delta_np = np.zeros_like(features)
+ mfcc_delta2_np = np.zeros_like(features)
+
+ for i in range(features.shape[1]):
+ idelta = np.convolve(features[:, i], self.__savgol_order1_coeffs, 'same')
+ mfcc_delta_np[:, i] = (idelta)
+ ideltadelta = np.convolve(features[:, i], self.savgol_order2_coeffs, 'same')
+ mfcc_delta2_np[:, i] = (ideltadelta)
+
+ features = np.concatenate((self.__normalize(features), self.__normalize(mfcc_delta_np),
+ self.__normalize(mfcc_delta2_np)), axis=1)
+
+ return np.float32(features)
diff --git a/python/pyarmnn/examples/speech_recognition/requirements.txt b/python/pyarmnn/examples/speech_recognition/requirements.txt
new file mode 100644
index 0000000000..4b8f3e6d24
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/requirements.txt
@@ -0,0 +1,2 @@
+numpy>=1.19.2
+soundfile>=0.10.3 \ No newline at end of file
diff --git a/python/pyarmnn/examples/speech_recognition/run_audio_file.py b/python/pyarmnn/examples/speech_recognition/run_audio_file.py
new file mode 100644
index 0000000000..c7e4c6bc31
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/run_audio_file.py
@@ -0,0 +1,94 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Automatic speech recognition with PyArmNN demo for processing audio clips to text."""
+
+import sys
+import os
+from argparse import ArgumentParser
+
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
+from network_executor import ArmnnNetworkExecutor
+from utils import dict_labels
+from preprocess import MFCCParams, Preprocessor, MFCC
+from audio_capture import AudioCapture, ModelParams
+from audio_utils import decode_text, prepare_input_tensors, display_text
+
+
+def parse_args():
+ parser = ArgumentParser(description="ASR with PyArmNN")
+ parser.add_argument(
+ "--audio_file_path",
+ required=True,
+ type=str,
+ help="Path to the audio file to perform ASR",
+ )
+ parser.add_argument(
+ "--model_file_path",
+ required=True,
+ type=str,
+ help="Path to ASR model to use",
+ )
+ parser.add_argument(
+ "--labels_file_path",
+ required=True,
+ type=str,
+ help="Path to text file containing labels to map to model output",
+ )
+ parser.add_argument(
+ "--preferred_backends",
+ type=str,
+ nargs="+",
+ default=["CpuAcc", "CpuRef"],
+ help="""List of backends in order of preference for optimizing
+ subgraphs, falling back to the next backend in the list on unsupported
+ layers. Defaults to [CpuAcc, CpuRef]""",
+ )
+ return parser.parse_args()
+
+
+def main(args):
+ # Read command line args
+ audio_file = args.audio_file_path
+ model = ModelParams(args.model_file_path)
+ labels = dict_labels(args.labels_file_path)
+
+ # Create the ArmNN inference runner
+ network = ArmnnNetworkExecutor(model.path, args.preferred_backends)
+
+ audio_capture = AudioCapture(model)
+ buffer = audio_capture.from_audio_file(audio_file)
+
+ # Create the preprocessor
+ mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=128, mel_lo_freq=0, mel_hi_freq=8000,
+ num_mfcc_feats=13, frame_len=512, use_htk_method=False, n_FFT=512)
+ mfcc = MFCC(mfcc_params)
+ preprocessor = Preprocessor(mfcc, model_input_size=1044, stride=160)
+
+ text = ""
+ current_r_context = ""
+ is_first_window = True
+
+ print("Processing Audio Frames...")
+ for audio_data in buffer:
+ # Prepare the input Tensors
+ input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+
+ # Run inference
+ output_result = network.run(input_tensors)
+
+ # Slice and Decode the text, and store the right context
+ current_r_context, text = decode_text(is_first_window, labels, output_result)
+
+ is_first_window = False
+
+ display_text(text)
+
+ print(current_r_context, flush=True)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/python/pyarmnn/examples/speech_recognition/tests/conftest.py b/python/pyarmnn/examples/speech_recognition/tests/conftest.py
new file mode 100644
index 0000000000..730c291cfa
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/conftest.py
@@ -0,0 +1,34 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import ntpath
+
+import urllib.request
+
+import pytest
+
+script_dir = os.path.dirname(__file__)
+
+@pytest.fixture(scope="session")
+def test_data_folder(request):
+ """
+    This fixture returns the path to the folder containing the test resources shared among all tests.
+ """
+
+ data_dir = os.path.join(script_dir, "testdata")
+
+ if not os.path.exists(data_dir):
+ os.mkdir(data_dir)
+
+ files_to_download = ["https://raw.githubusercontent.com/Azure-Samples/cognitive-services-speech-sdk/master"
+ "/sampledata/audiofiles/myVoiceIsMyPassportVerifyMe04.wav"]
+
+ for file in files_to_download:
+ path, filename = ntpath.split(file)
+ file_path = os.path.join(script_dir, "testdata", filename)
+ if not os.path.exists(file_path):
+ print("\nDownloading test file: " + file_path + "\n")
+ urllib.request.urlretrieve(file, file_path)
+
+ return data_dir
diff --git a/python/pyarmnn/examples/speech_recognition/tests/context.py b/python/pyarmnn/examples/speech_recognition/tests/context.py
new file mode 100644
index 0000000000..a810010e9f
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/context.py
@@ -0,0 +1,13 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'common'))
+import utils as common_utils
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+import audio_capture
+import audio_utils
+import preprocess
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py b/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py
new file mode 100644
index 0000000000..281d0df587
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py
@@ -0,0 +1,17 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+
+import numpy as np
+
+from context import audio_capture
+
+
+def test_audio_file(test_data_folder):
+ audio_file = os.path.join(test_data_folder, "myVoiceIsMyPassportVerifyMe04.wav")
+ capture = audio_capture.AudioCapture(audio_capture.ModelParams(""))
+ buffer = capture.from_audio_file(audio_file)
+ audio_data = next(buffer)
+ assert audio_data.shape == (47712,)
+ assert audio_data.dtype == np.float32
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py b/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py
new file mode 100644
index 0000000000..3b99e6504a
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py
@@ -0,0 +1,28 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+
+import numpy as np
+
+from context import common_utils
+from context import audio_utils
+
+
+def test_labels(test_data_folder):
+ labels_file = os.path.join(test_data_folder, "wav2letter_labels.txt")
+ labels = common_utils.dict_labels(labels_file)
+ assert len(labels) == 29
+ assert labels[26] == "\'"
+ assert labels[27] == r" "
+ assert labels[28] == "$"
+
+
+def test_decoder(test_data_folder):
+ labels_file = os.path.join(test_data_folder, "wav2letter_labels.txt")
+ labels = common_utils.dict_labels(labels_file)
+
+ output_tensor = os.path.join(test_data_folder, "inf_out.npy")
+ encoded = np.load(output_tensor)
+ decoded_text = audio_utils.decode(encoded, labels)
+ assert decoded_text == "and he walkd immediately out of the apartiment by anothe"
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py b/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py
new file mode 100644
index 0000000000..d692ab51c8
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py
@@ -0,0 +1,286 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import numpy as np
+
+from context import preprocess
+
+test_wav = [
+ -3,0,1,-1,2,3,-2,2,
+ 1,-2,0,3,-1,8,3,2,
+ -1,-1,2,7,3,5,6,6,
+ 6,12,5,6,3,3,5,4,
+ 4,6,7,7,7,3,7,2,
+ 8,4,4,2,-4,-1,-1,-4,
+ 2,1,-1,-4,0,-7,-6,-2,
+ -5,1,-5,-1,-7,-3,-3,-7,
+ 0,-3,3,-5,0,1,-2,-2,
+ -3,-3,-7,-3,-2,-6,-5,-8,
+ -2,-8,4,-9,-4,-9,-5,-5,
+ -3,-9,-3,-9,-1,-7,-4,1,
+ -3,2,-8,-4,-4,-5,1,-3,
+ -1,0,-1,-2,-3,-2,-4,-1,
+ 1,-1,3,0,3,2,0,0,
+ 0,-3,1,1,0,8,3,4,
+ 1,5,6,4,7,3,3,0,
+ 3,6,7,6,4,5,9,9,
+ 5,5,8,1,6,9,6,6,
+ 7,1,8,1,5,0,5,5,
+ 0,3,2,7,2,-3,3,0,
+ 3,0,0,0,2,0,-1,-1,
+ -2,-3,-8,0,1,0,-3,-3,
+ -3,-2,-3,-3,-4,-6,-2,-8,
+ -9,-4,-1,-5,-3,-3,-4,-3,
+ -6,3,0,-1,-2,-9,-4,-2,
+ 2,-1,3,-5,-5,-2,0,-2,
+ 0,-1,-3,1,-2,9,4,5,
+ 2,2,1,0,-6,-2,0,0,
+ 0,-1,4,-4,3,-7,-1,5,
+ -6,-1,-5,4,3,9,-2,1,
+ 3,0,0,-2,1,2,1,1,
+ 0,3,2,-1,3,-3,7,0,
+ 0,3,2,2,-2,3,-2,2,
+ -3,4,-1,-1,-5,-1,-3,-2,
+ 1,-1,3,2,4,1,2,-2,
+ 0,2,7,0,8,-3,6,-3,
+ 6,1,2,-3,-1,-1,-1,1,
+ -2,2,1,2,0,-2,3,-2,
+ 3,-2,1,0,-3,-1,-2,-4,
+ -6,-5,-8,-1,-4,0,-3,-1,
+ -1,-1,0,-2,-3,-7,-1,0,
+ 1,5,0,5,1,1,-3,0,
+ -6,3,-8,4,-8,6,-6,1,
+ -6,-2,-5,-6,0,-5,4,-1,
+ 4,-2,1,2,1,0,-2,0,
+ 0,2,-2,2,-5,2,0,-2,
+ 1,-2,0,5,1,0,1,5,
+ 0,8,3,2,2,0,5,-2,
+ 3,1,0,1,0,-2,-1,-3,
+ 1,-1,3,0,3,0,-2,-1,
+ -4,-4,-4,-1,-4,-4,-3,-6,
+ -3,-7,-3,-1,-2,0,-5,-4,
+ -7,-3,-2,-2,1,2,2,8,
+ 5,4,2,4,3,5,0,3,
+ 3,6,4,2,2,-2,4,-2,
+ 3,3,2,1,1,4,-5,2,
+ -3,0,-1,1,-2,2,5,1,
+ 4,2,3,1,-1,1,0,6,
+ 0,-2,-1,1,-1,2,-5,-1,
+ -5,-1,-6,-3,-3,2,4,0,
+ -1,-5,3,-4,-1,-3,-4,1,
+ -4,1,-1,-1,0,-5,-4,-2,
+ -1,-1,-3,-7,-3,-3,4,4,
+]
+
+def test_mel_scale_function_with_htk_true():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+    mel_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+    mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mel_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel = mfcc_inst.mel_scale(16, True)
+
+ assert np.isclose(mel, 25.470010570730597)
+
+
+def test_mel_scale_function_with_htk_false():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+    mel_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+    mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mel_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel = mfcc_inst.mel_scale(16, False)
+
+ assert np.isclose(mel, 0.24)
+
+
+def test_inverse_mel_scale_function_with_htk_true():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+    mel_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+    mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mel_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel = mfcc_inst.inv_mel_scale(16, True)
+
+ assert np.isclose(mel, 10.008767240008943)
+
+
+def test_inverse_mel_scale_function_with_htk_false():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+    mel_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+    mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mel_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel = mfcc_inst.inv_mel_scale(16, False)
+
+ assert np.isclose(mel, 1071.170287494467)
+
+
+def test_create_mel_filter_bank():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+    mel_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+    mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mel_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel_filter_bank = mfcc_inst.create_mel_filter_bank()
+
+ assert len(mel_filter_bank) == 128
+
+ assert str(mel_filter_bank[0]) == "[0.02837754]"
+ assert str(mel_filter_bank[1]) == "[0.01438901 0.01398853]"
+ assert str(mel_filter_bank[2]) == "[0.02877802]"
+ assert str(mel_filter_bank[3]) == "[0.04236608]"
+ assert str(mel_filter_bank[4]) == "[0.00040047 0.02797707]"
+ assert str(mel_filter_bank[5]) == "[0.01478948 0.01358806]"
+ assert str(mel_filter_bank[50]) == "[0.03298853]"
+ assert str(mel_filter_bank[100]) == "[0.00260166 0.00588759 0.00914814 0.00798015 0.00476919 0.00158245]"
+
+
+def test_mfcc_compute():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+    mel_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ audio_data = np.array(test_wav) / (2 ** 15)
+
+    mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mel_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+ mfcc_feats = mfcc_inst.mfcc_compute(audio_data)
+
+ assert np.isclose((mfcc_feats[0]), -834.9656973095651)
+ assert np.isclose((mfcc_feats[1]), 21.026915475076322)
+ assert np.isclose((mfcc_feats[2]), 18.628541708201688)
+ assert np.isclose((mfcc_feats[3]), 7.341153529494758)
+ assert np.isclose((mfcc_feats[4]), 18.907974386153214)
+ assert np.isclose((mfcc_feats[5]), -5.360387487466194)
+ assert np.isclose((mfcc_feats[6]), 6.523572638527085)
+ assert np.isclose((mfcc_feats[7]), -11.270643644983316)
+ assert np.isclose((mfcc_feats[8]), 8.375177203773777)
+ assert np.isclose((mfcc_feats[9]), 12.06721844362991)
+ assert np.isclose((mfcc_feats[10]), 8.30815892468875)
+ assert np.isclose((mfcc_feats[11]), -13.499911910889917)
+ assert np.isclose((mfcc_feats[12]), -18.176121251436165)
+
+
+def test_sliding_window_for_small_num_samples():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+    model_input_size = 9
+ stride = 160
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+    mel_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ audio_data = np.array(test_wav) / (2 ** 15)
+
+ full_audio_data = np.tile(audio_data, 9)
+
+    mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mel_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+    preprocessor = preprocess.Preprocessor(mfcc_inst, model_input_size, stride)
+
+ input_tensor = preprocessor.extract_features(full_audio_data)
+
+ assert np.isclose(input_tensor[0][0], -3.4660944830426454)
+ assert np.isclose(input_tensor[0][1], 0.3587718932127629)
+ assert np.isclose(input_tensor[0][2], 0.3480551325669172)
+ assert np.isclose(input_tensor[0][3], 0.2976191917228921)
+ assert np.isclose(input_tensor[0][4], 0.3493037340849936)
+ assert np.isclose(input_tensor[0][5], 0.2408643285767937)
+ assert np.isclose(input_tensor[0][6], 0.2939659585037282)
+ assert np.isclose(input_tensor[0][7], 0.2144552669573928)
+ assert np.isclose(input_tensor[0][8], 0.302239565899944)
+ assert np.isclose(input_tensor[0][9], 0.3187368787077345)
+ assert np.isclose(input_tensor[0][10], 0.3019401051295793)
+ assert np.isclose(input_tensor[0][11], 0.20449412797602678)
+
+ assert np.isclose(input_tensor[0][38], -0.18751440767749533)
+
+
+def test_sliding_window_for_wav_2_letter_sized_input():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+    model_input_size = 296
+ stride = 160
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+    mel_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ audio_data = np.zeros(47712, dtype=int)
+
+    mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mel_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+    preprocessor = preprocess.Preprocessor(mfcc_inst, model_input_size, stride)
+
+ input_tensor = preprocessor.extract_features(audio_data)
+
+ assert len(input_tensor[0]) == 39
+ assert len(input_tensor) == 296
diff --git a/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy b/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy
new file mode 100644
index 0000000000..a6f9ec0c70
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy
Binary files differ
diff --git a/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt b/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt
new file mode 100644
index 0000000000..d7485b7da2
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt
@@ -0,0 +1,29 @@
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+'
+
+$ \ No newline at end of file