aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/keyword_spotting
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/keyword_spotting')
-rw-r--r--python/pyarmnn/examples/keyword_spotting/README.MD189
-rw-r--r--python/pyarmnn/examples/keyword_spotting/__init__.py0
-rw-r--r--python/pyarmnn/examples/keyword_spotting/audio_utils.py31
-rw-r--r--python/pyarmnn/examples/keyword_spotting/requirements.txt5
-rw-r--r--python/pyarmnn/examples/keyword_spotting/run_audio_classification.py136
5 files changed, 361 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/keyword_spotting/README.MD b/python/pyarmnn/examples/keyword_spotting/README.MD
new file mode 100644
index 0000000000..4299fa0fd4
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/README.MD
@@ -0,0 +1,189 @@
+# Keyword Spotting with PyArmNN
+
+This sample application guides the user to perform Keyword Spotting (KWS) with PyArmNN API.
+
+## Prerequisites
+
+### PyArmNN
+
+Before proceeding to the next steps, make sure that you have successfully installed the newest version of PyArmNN on your system by following the instructions in the README of the PyArmNN root directory.
+
+You can verify that PyArmNN library is installed and check PyArmNN version using:
+
+```bash
+$ pip show pyarmnn
+```
+
+You can also verify it by running the following and getting output similar to below:
+
+```bash
+$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
+'26.0.0'
+```
+
+### Dependencies
+
+Install the PortAudio package:
+
+```bash
+$ sudo apt-get install libsndfile1 libportaudio2
+```
+
+Install the required Python modules:
+
+```bash
+$ pip install -r requirements.txt
+```
+
+### Model
+
+The model we are using is the [DS CNN Large](https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/) which can be found in the [Arm Model Zoo repository](
+https://github.com/ARM-software/ML-zoo/tree/master/models).
+
+A small selection of suitable wav files containing keywords can be found [here](https://git.mlplatform.org/ml/ethos-u/ml-embedded-evaluation-kit.git/plain/resources/kws/samples/).
+
+Labels for this model are defined within run_audio_classification.py.
+
+## Performing Keyword Spotting
+
+### Processing Audio Files
+
+Please ensure that your audio file has a sampling rate of 16000Hz.
+
+To run KWS on an audio file, use the following command:
+
+```bash
+$ python run_audio_classification.py --audio_file_path <path/to/your_audio> --model_file_path <path/to/your_model>
+```
+
+You may also add the optional flags:
+
+* `--preferred_backends`
+
+ * Takes the preferred backends in preference order, separated by whitespace. For example, passing in "CpuAcc CpuRef" will be read as list ["CpuAcc", "CpuRef"] (defaults to this list)
+
+ * CpuAcc represents the CPU backend
+
+ * GpuAcc represents the GPU backend
+
+ * CpuRef represents the CPU reference kernels
+
+* `--help` prints all available options to screen
+
+
+### Processing Audio Streams
+
+To run KWS on a live audio stream, use the following command:
+
+```bash
+$ python run_audio_classification.py --model_file_path <path/to/your_model> --duration (optional)
+```
+You will be prompted to select an input microphone and inference will commence
+after 3 seconds.
+
+
+You may also add the following optional flag in addition to those for run_audio_file.py:
+
+* `--duration`
+
+ * Integer number of seconds to perform inference. Default is to continue indefinitely.
+
+## Application Overview
+
+1. [Initialization](#initialization)
+
+2. [Creating a network](#creating-a-network)
+
+3. [Keyword Spotting Pipeline](#keyword-spotting-pipeline)
+
+### Initialization
+
+The application parses the supplied user arguments and loads the audio file or stream in chunks through the `capture_audio()` method which accepts sampling criteria as an `AudioCaptureParams` tuple.
+
+With KWS from an audio file, the application will create a generator object to yield blocks of audio data from the file with a minimum sample size defined in AudioCaptureParams.
+
+MFCC features are extracted from each block based on criteria defined in the `MFCCParams` tuple. These extracted features constitute the input tensors for the model.
+
+To interpret the inference result of the loaded network; the application passes the label dictionary defined in run_audio_classification.py to a decoder and displays the result.
+
+### Creating a network
+
+A PyArmNN application must import a graph from file using an appropriate parser. Arm NN provides parsers for various model file types, including TFLite and ONNX. These parsers are libraries for loading neural networks of various formats into the Arm NN runtime.
+
+Arm NN supports optimized execution on multiple CPU, GPU, and Ethos-N devices. Before executing a graph, the application must select the appropriate device context by using `IRuntime()` to create a runtime context with default options. We can optimize the imported graph by specifying a list of backends in order of preference and implementing backend-specific optimizations, identified by a unique string, for example CpuAcc, GpuAcc, CpuRef represent the accelerated CPU and GPU backends and the CPU reference kernels respectively.
+
+Arm NN splits the entire graph into subgraphs based on these backends. Each subgraph is then optimized, and the corresponding subgraph in the original graph is substituted with its optimized version.
+
+The `Optimize()` function optimizes the graph for inference, then `LoadNetwork()` loads the optimized network onto the compute device. The `LoadNetwork()` function also creates the backend-specific workloads for the layers and a backend-specific workload factory.
+
+Parsers extract the input information for the network. The `GetSubgraphInputTensorNames()` function extracts all the input names and the `GetNetworkInputBindingInfo()` function obtains the input binding information of the graph. The input binding information contains all the essential information about the input. This information is a tuple consisting of integer identifiers for bindable layers and tensor information (data type, quantization info, dimension count, total elements).
+
+Similarly, we can get the output binding information for an output layer by using the parser to retrieve output tensor names and calling the `GetNetworkOutputBindingInfo()` function
+
+For this application, the main point of contact with PyArmNN is through the `ArmnnNetworkExecutor` class, which will handle the network creation step for you.
+
+```python
+# common/network_executor.py
+# The provided kws model is in .tflite format so we use TfLiteParser() to import the graph
+if ext == '.tflite':
+ parser = ann.ITfLiteParser()
+network = parser.CreateNetworkFromBinaryFile(model_file)
+...
+# Optimize the network for the list of preferred backends
+opt_network, messages = ann.Optimize(
+ network, preferred_backends, self.runtime.GetDeviceSpec(), ann.OptimizerOptions()
+ )
+# Load the optimized network onto the runtime device
+self.network_id, _ = self.runtime.LoadNetwork(opt_network)
+# Get the input and output binding information
+self.input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+self.output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
+```
+
+### Keyword Spotting pipeline
+
+
+Mel-frequency Cepstral Coefficients (MFCCs, [see Wikipedia](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) are extracted based on criteria defined in the MFCCParams tuple and associated`MFCC Class`.
+MFCCs are the result of computing the dot product of the Discrete Cosine Transform (DCT) Matrix and the log of the Mel energy.
+
+The `MFCC` class is used in conjunction with the `AudioPreProcessor` class to extract and process MFCC features from a given audio frame.
+
+
+After all the MFCCs needed for an inference have been extracted from the audio data they constitute the input tensors that will be classified by an `ArmnnNetworkExecutor`object.
+
+```python
+# mfcc.py
+# Extract MFCC features from audio_data
+audio_data.resize(self._frame_len_padded)
+spec = self.spectrum_calc(audio_data)
+mel_energy = np.dot(self._np_mel_bank.astype(np.float32),
+ np.transpose(spec).astype(np.float32))
+log_mel_energy = self.log_mel(mel_energy)
+mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
+
+
+```python
+# audio_utils.py
+# Quantize the input data and create input tensors with PyArmNN
+input_tensor = quantize_input(input_tensor, input_binding_info)
+input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+```
+
+Note: `ArmnnNetworkExecutor` has already created the output tensors for you.
+
+After creating the workload tensors, the compute device performs inference for the loaded network by using the `EnqueueWorkload()` function of the runtime context. Calling the `workload_tensors_to_ndarray()` function obtains the inference results as a list of ndarrays.
+
+```python
+# common/network_executor.py
+status = runtime.EnqueueWorkload(net_id, input_tensors, self.output_tensors)
+self.output_result = ann.workload_tensors_to_ndarray(self.output_tensors)
+```
+
+The output from the inference must be decoded to obtain the recognised classification. A simple greedy decoder classifies the results by taking the highest element of the output as a key for the labels dictionary. The value returned is a keyword or unknown/silence which is appended to a list along with the calculated probability. The list elements are displayed on the console if they exceed the threshold value specified in main().
+
+
+## Next steps
+
+Having now gained a solid understanding of performing keyword spotting with PyArmNN, you are able to take control and create your own application. We suggest to first implement your own network, which can be done by updating the parameters of `AudioCaptureParams` and `MFCC_Params` to match your custom model. The `ArmnnNetworkExecutor` class will handle the network optimisation and loading for you.
+
+An important factor in improving accuracy of the generated output is providing cleaner data to the network. This can be done by including additional preprocessing steps such as noise reduction of your audio data.
diff --git a/python/pyarmnn/examples/keyword_spotting/__init__.py b/python/pyarmnn/examples/keyword_spotting/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/__init__.py
diff --git a/python/pyarmnn/examples/keyword_spotting/audio_utils.py b/python/pyarmnn/examples/keyword_spotting/audio_utils.py
new file mode 100644
index 0000000000..723c0e38f6
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/audio_utils.py
@@ -0,0 +1,31 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Utilities for speech recognition apps."""
+
+import numpy as np
+
+
+def decode(model_output: np.ndarray, labels: dict) -> list:
+ """Decodes the integer encoded results from inference into a string.
+
+ Args:
+ model_output: Results from running inference.
+ labels: Dictionary of labels keyed on the classification index.
+
+ Returns:
+ Decoded string.
+ """
+ results = [labels[np.argmax(model_output)], model_output[0][0][np.argmax(model_output)]]
+
+ return results
+
+
+def display_text(text: list):
+ """Presents the results on the console.
+
+ Args:
+ text: Results of performing ASR on the input audio data.
+ """
+ print('Classification: %s' % text[0])
+ print('Probability: %s' % text[1])
diff --git a/python/pyarmnn/examples/keyword_spotting/requirements.txt b/python/pyarmnn/examples/keyword_spotting/requirements.txt
new file mode 100644
index 0000000000..96782eafd0
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/requirements.txt
@@ -0,0 +1,5 @@
+numpy>=1.19.2
+soundfile>=0.10.3
+pytest==6.2.4
+pytest-allclose==1.0.0
+sounddevice==0.4.2
diff --git a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
new file mode 100644
index 0000000000..6dfa4cc806
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
@@ -0,0 +1,136 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""
+
+import sys
+import os
+from argparse import ArgumentParser
+
+import numpy as np
+import sounddevice as sd
+
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
+from network_executor import ArmnnNetworkExecutor
+from utils import prepare_input_tensors, dequantize_output
+from mfcc import AudioPreprocessor, MFCC, MFCCParams
+from audio_utils import decode, display_text
+from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio
+
+# Model Specific Labels
+labels = {0: 'silence',
+ 1: 'unknown',
+ 2: 'yes',
+ 3: 'no',
+ 4: 'up',
+ 5: 'down',
+ 6: 'left',
+ 7: 'right',
+ 8: 'on',
+ 9: 'off',
+ 10: 'stop',
+ 11: 'go'}
+
+
+def parse_args():
+ parser = ArgumentParser(description="KWS with PyArmNN")
+ parser.add_argument(
+ "--audio_file_path",
+ required=False,
+ type=str,
+ help="Path to the audio file to perform KWS",
+ )
+ parser.add_argument(
+ "--duration",
+ type=int,
+ default=0,
+ help="""Duration for recording audio in seconds. Values <= 0 result in infinite
+ recording. Defaults to infinite.""",
+ )
+ parser.add_argument(
+ "--model_file_path",
+ required=True,
+ type=str,
+ help="Path to KWS model to use",
+ )
+ parser.add_argument(
+ "--preferred_backends",
+ type=str,
+ nargs="+",
+ default=["CpuAcc", "CpuRef"],
+ help="""List of backends in order of preference for optimizing
+ subgraphs, falling back to the next backend in the list on unsupported
+ layers. Defaults to [CpuAcc, CpuRef]""",
+ )
+ return parser.parse_args()
+
+
+def recognise_speech(audio_data, network, preprocessor, threshold):
+ # Prepare the input Tensors
+ input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+ # Run inference
+ output_result = network.run(input_tensors)
+
+ dequantized_result = []
+ for index, ofm in enumerate(output_result):
+ dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index]))
+
+ # Decode the text and display result if above threshold
+ decoded_result = decode(dequantized_result, labels)
+
+ if decoded_result[1] > threshold:
+ display_text(decoded_result)
+
+
+def main(args):
+ # Read command line args and invoke mic streaming if no file path supplied
+ audio_file = args.audio_file_path
+ if args.audio_file_path:
+ streaming_enabled = False
+ else:
+ streaming_enabled = True
+ # Create the ArmNN inference runner
+ network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+
+ # Specify model specific audio data requirements
+ # Overlap value specifies the number of samples to rewind between each data window
+ audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000, sampling_freq=16000,
+ mono=True)
+
+ # Create the preprocessor
+ mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
+ num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
+ mfcc = MFCC(mfcc_params)
+ preprocessor = AudioPreprocessor(mfcc, model_input_size=49, stride=320)
+
+ # Set threshold for displaying classification and commence stream or file processing
+ threshold = .90
+ if streaming_enabled:
+ # Initialise audio stream
+ record_stream = CaptureAudioStream(audio_capture_params)
+ record_stream.set_stream_defaults()
+ record_stream.set_recording_duration(args.duration)
+ record_stream.countdown()
+
+ with sd.InputStream(callback=record_stream.callback):
+ print("Recording audio. Please speak.")
+ while record_stream.is_active:
+
+ audio_data = record_stream.capture_data()
+ recognise_speech(audio_data, network, preprocessor, threshold)
+ record_stream.is_first_window = False
+ print("\nFinished recording.")
+
+ # If file path has been supplied read-in and run inference
+ else:
+ print("Processing Audio Frames...")
+ buffer = capture_audio(audio_file, audio_capture_params)
+ for audio_data in buffer:
+ recognise_speech(audio_data, network, preprocessor, threshold)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)