Diffstat (limited to 'python/pyarmnn/examples/speech_recognition')
-rw-r--r--  python/pyarmnn/examples/speech_recognition/README.md                            | 158
-rw-r--r--  python/pyarmnn/examples/speech_recognition/__init__.py                          |   0
-rw-r--r--  python/pyarmnn/examples/speech_recognition/audio_capture.py                     |  56
-rw-r--r--  python/pyarmnn/examples/speech_recognition/audio_utils.py                       | 128
-rw-r--r--  python/pyarmnn/examples/speech_recognition/preprocess.py                        | 260
-rw-r--r--  python/pyarmnn/examples/speech_recognition/requirements.txt                     |   2
-rw-r--r--  python/pyarmnn/examples/speech_recognition/run_audio_file.py                    |  94
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/conftest.py                    |  34
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/context.py                     |  13
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py             |  17
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_decoder.py                |  28
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py                   | 286
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy           | bin 0 -> 4420 bytes
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt |  29
14 files changed, 1105 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/speech_recognition/README.md b/python/pyarmnn/examples/speech_recognition/README.md
new file mode 100644
index 0000000000..10a583f123
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/README.md
@@ -0,0 +1,158 @@
+# Automatic Speech Recognition with PyArmNN
+
+This sample application demonstrates how to perform automatic speech recognition (ASR) with the PyArmNN API.
+
+## Prerequisites
+
+### PyArmNN
+
+Before proceeding to the next steps, make sure that you have successfully installed the newest version of PyArmNN on your system by following the instructions in the README of the PyArmNN root directory.
+
+You can verify that the PyArmNN library is installed and check its version using:
+
+```bash
+$ pip show pyarmnn
+```
+
+You can also verify the installation by running the following command and checking for output similar to the example below:
+
+```bash
+$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
+'22.0.0'
+```
+
+### Dependencies
+
+Install the libsndfile and PortAudio system packages:
+
+```bash
+$ sudo apt-get install libsndfile1 libportaudio2
+```
+
+Install the required Python modules:
+
+```bash
+$ pip install -r requirements.txt
+```
+
+## Performing Automatic Speech Recognition
+
+### Processing Audio Files
+
+To run ASR on an audio file, use the following command:
+
+```bash
+$ python run_audio_file.py --audio_file_path <path/to/your_audio> --model_file_path <path/to/your_model> --labels_file_path <path/to/your_labels>
+```
+
+You may also add the optional flags:
+
+* `--preferred_backends`
+
+ * Takes the preferred backends in preference order, separated by whitespace. For example, passing in "CpuAcc CpuRef" will be read as the list ["CpuAcc", "CpuRef"], which is also the default.
+
+ * CpuAcc represents the CPU backend
+
+ * GpuAcc represents the GPU backend
+
+ * CpuRef represents the CPU reference kernels
+
+* `--help` prints all available options to screen
+
+## Application Overview
+
+1. [Initialization](#initialization)
+
+2. [Creating a network](#creating-a-network)
+
+3. [Automatic speech recognition pipeline](#automatic-speech-recognition-pipeline)
+
+### Initialization
+
+The application parses the supplied user arguments and loads the audio file into the `AudioCapture` class, which initialises the audio source and sets the sampling parameters required by the model using the `ModelParams` class.
+
+`AudioCapture` captures chunks of audio data from the source. When performing ASR on an audio file, the application creates a generator object that yields blocks of audio data from the file, each padded to the minimum number of samples required for one inference.
+
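+A minimal sketch of how the generator is consumed is shown below (the file paths and the `process` step are placeholders for illustration; `run_audio_file.py` shows the full flow):
+
+```python
+from audio_capture import AudioCapture, ModelParams
+
+# Hypothetical paths, for illustration only
+model_params = ModelParams("path/to/wav2letter.tflite")
+capture = AudioCapture(model_params)
+
+# Each yielded block is zero-padded up to the model's minimum sample count
+for audio_block in capture.from_audio_file("path/to/speech.wav"):
+    process(audio_block)  # placeholder for feature extraction and inference
+```
+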
+To interpret the inference result of the loaded network, the application must load the labels that are associated with the model. The `dict_labels()` function creates a dictionary that is keyed on the classification index at the output node of the model. The values of the dictionary are the corresponding characters.
+
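+A minimal sketch of this label-loading step follows (the example itself uses `dict_labels()` from `common/utils.py`; this illustrative version assumes one character per line, as in the provided wav2letter labels file):
+
+```python
+def load_labels(labels_file_path: str) -> dict:
+    """Builds a dictionary mapping each classification index to its character."""
+    labels = {}
+    with open(labels_file_path, "r") as label_file:
+        for idx, line in enumerate(label_file):
+            labels[idx] = line.rstrip("\n")
+    return labels
+```
+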
+### Creating a network
+
+A PyArmNN application must import a graph from file using an appropriate parser. Arm NN provides parsers for various model file types, including TFLite, TF, and ONNX. These parsers are libraries for loading neural networks of various formats into the Arm NN runtime.
+
+Arm NN supports optimized execution on multiple CPU, GPU, and Ethos-N devices. Before executing a graph, the application must select the appropriate device context by using `IRuntime()` to create a runtime context with default options. We can optimize the imported graph by specifying a list of backends in order of preference and implementing backend-specific optimizations. Each backend is identified by a unique string; for example, CpuAcc, GpuAcc, and CpuRef represent the accelerated CPU backend, the GPU backend, and the CPU reference kernels respectively.
+
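+A minimal sketch of the runtime creation step (this happens inside `ArmnnNetworkExecutor`; the calls follow the PyArmNN API):
+
+```python
+import pyarmnn as ann
+
+# Create an Arm NN runtime context with default creation options
+options = ann.CreationOptions()
+runtime = ann.IRuntime(options)
+
+# The device spec describes the backends available on this machine
+print(runtime.GetDeviceSpec())
+```
+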
+Arm NN splits the entire graph into subgraphs based on these backends. Each subgraph is then optimized, and the corresponding subgraph in the original graph is substituted with its optimized version.
+
+The `Optimize()` function optimizes the graph for inference, then `LoadNetwork()` loads the optimized network onto the compute device. The `LoadNetwork()` function also creates the backend-specific workloads for the layers and a backend-specific workload factory.
+
+Parsers extract the input information for the network. The `GetSubgraphInputTensorNames()` function extracts all the input names and the `GetNetworkInputBindingInfo()` function obtains the input binding information of the graph. The input binding information contains all the essential information about the input. This information is a tuple consisting of integer identifiers for bindable layers and tensor information (data type, quantization info, dimension count, total elements).
+
+Similarly, we can get the output binding information for an output layer by using the parser to retrieve output tensor names and calling the `GetNetworkOutputBindingInfo()` function.
+
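+As a rough sketch of that step (the calls follow the PyArmNN `ITfLiteParser` API; using the last subgraph is an assumption that holds for the single-subgraph wav2letter model):
+
+```python
+# Retrieve tensor names and binding information from the parser
+graph_id = parser.GetSubgraphCount() - 1
+input_names = parser.GetSubgraphInputTensorNames(graph_id)
+output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+
+input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
+```
+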
+For this application, the main point of contact with PyArmNN is through the `ArmnnNetworkExecutor` class, which will handle the network creation step for you.
+
+```python
+# common/network_executor.py
+# The provided wav2letter model is in .tflite format so we use TfLiteParser() to import the graph
+if ext == '.tflite':
+ parser = ann.ITfLiteParser()
+network = parser.CreateNetworkFromBinaryFile(model_file)
+...
+# Optimize the network for the list of preferred backends
+opt_network, messages = ann.Optimize(
+ network, preferred_backends, self.runtime.GetDeviceSpec(), ann.OptimizerOptions()
+ )
+# Load the optimized network onto the runtime device
+self.network_id, _ = self.runtime.LoadNetwork(opt_network)
+# Get the input and output binding information
+self.input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+self.output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
+```
+
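+In this application, that wrapper is created with just the model path and the preferred backends, as done in `run_audio_file.py`:
+
+```python
+# run_audio_file.py
+network = ArmnnNetworkExecutor(model.path, args.preferred_backends)
+```
+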
+### Automatic speech recognition pipeline
+
+The `MFCC` class is used to extract the Mel-frequency Cepstral Coefficients (MFCCs, [see Wikipedia](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) from a given audio frame to be used as features for the network. MFCCs are the result of computing the dot product of the Discrete Cosine Transform (DCT) matrix and the log of the Mel energy.
+
+After all the MFCCs needed for an inference have been extracted from the audio data, we convolve them with 1-dimensional Savitzky-Golay filters to compute the first and second MFCC derivatives with respect to time. The MFCCs and the derivatives are concatenated to make the input tensor for the model.
+
+```python
+# preprocess.py
+# Extract MFCC features
+log_mel_energy = np.maximum(log_mel_energy, log_mel_energy.max() - top_db)
+mfcc_feats = np.dot(self.__dct_matrix, log_mel_energy)
+...
+# Compute first and second derivatives (delta and delta-delta respectively) by passing a
+# Savitzky-Golay filter as a 1D convolution over the features
+for i in range(features.shape[1]):
+    idelta = np.convolve(features[:, i], self.__savgol_order1_coeffs, 'same')
+    mfcc_delta_np[:, i] = idelta
+    ideltadelta = np.convolve(features[:, i], self.savgol_order2_coeffs, 'same')
+    mfcc_delta2_np[:, i] = ideltadelta
+```
+
+```python
+# audio_utils.py
+# Quantize the input data and create input tensors with PyArmNN
+input_tensor = quantize_input(input_tensor, input_binding_info)
+input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+```
+
+Note: `ArmnnNetworkExecutor` has already created the output tensors for you.
+
+After creating the workload tensors, the compute device performs inference for the loaded network by using the `EnqueueWorkload()` function of the runtime context. Calling the `workload_tensors_to_ndarray()` function obtains the inference results as a list of ndarrays.
+
+```python
+# common/network_executor.py
+status = runtime.EnqueueWorkload(net_id, input_tensors, self.output_tensors)
+self.output_result = ann.workload_tensors_to_ndarray(self.output_tensors)
+```
+
+The output from the inference must be decoded to obtain the recognised characters from the speech. A simple greedy decoder classifies the results by taking the index of the highest-scoring element at each time step and using it as a key into the labels dictionary. Each value returned is a character, which is appended to a list; the list is then filtered to remove unwanted and duplicate characters. The resulting string is displayed on the console.
+
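+A simplified sketch of that greedy decoding step (the full version, including the filtering of blank and duplicate characters, lives in `audio_utils.py`):
+
+```python
+import numpy as np
+
+def greedy_decode(model_output, labels) -> str:
+    """Maps the highest-scoring class index at each time step to its character."""
+    return "".join(labels[int(np.argmax(row[0]))] for row in model_output)
+```
+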
+## Next steps
+
+Having gained an understanding of how to perform automatic speech recognition with PyArmNN, you can now create your own application. As a next step, we suggest trying your own network, which you can do by updating the parameters of `ModelParams` and `MFCCParams` to match your custom model. The `ArmnnNetworkExecutor` class will handle the network optimisation and loading for you.
+
+An important step towards improving the accuracy of the generated output sentences is to provide cleaner data to the network. This can be done by adding preprocessing steps such as noise reduction of your audio data.
+
+In this application, we used a greedy decoder to decode the integer-encoded output; however, better results can be achieved by implementing a beam search decoder. You could also add a language model at the end to correct any spelling mistakes the model may produce.
diff --git a/python/pyarmnn/examples/speech_recognition/__init__.py b/python/pyarmnn/examples/speech_recognition/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/__init__.py
diff --git a/python/pyarmnn/examples/speech_recognition/audio_capture.py b/python/pyarmnn/examples/speech_recognition/audio_capture.py
new file mode 100644
index 0000000000..9f28d1006e
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/audio_capture.py
@@ -0,0 +1,56 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Contains AudioCapture class for capturing chunks of audio data from file."""
+
+from typing import Generator
+
+import numpy as np
+import soundfile as sf
+
+
+class ModelParams:
+ def __init__(self, model_file_path: str):
+ """Defines sampling parameters for model used.
+
+ Args:
+ model_file_path: Path to ASR model to use.
+ """
+ self.path = model_file_path
+ self.mono = True
+ self.dtype = np.float32
+ self.samplerate = 16000
+ self.min_samples = 167392
+
+
+class AudioCapture:
+ def __init__(self, model_params):
+        """Initialises AudioCapture with the sampling parameters for the model used."""
+ self.model_params = model_params
+
+ def from_audio_file(self, audio_file_path, overlap=31712) -> Generator[np.ndarray, None, None]:
+ """Creates a generator that yields audio data from a file. Data is padded with
+ zeros if necessary to make up minimum number of samples.
+
+ Args:
+ audio_file_path: Path to audio file provided by user.
+ overlap: The overlap with previous buffer. We need the offset to be the same as the inner context
+ of the mfcc output, which is sized as 100 x 39. Each mfcc compute produces 1 x 39 vector,
+ and consumes 160 audio samples. The default overlap is then calculated to be 47712 - (160 x 100)
+ where 47712 is the min_samples needed for 1 inference of wav2letter.
+
+ Yields:
+ Blocks of audio data of minimum sample size.
+ """
+ with sf.SoundFile(audio_file_path) as audio_file:
+ for block in audio_file.blocks(
+ blocksize=self.model_params.min_samples,
+ dtype=self.model_params.dtype,
+ always_2d=True,
+ fill_value=0,
+ overlap=overlap
+ ):
+                # Average the channels (and flatten to 1-D) if the model expects mono input
+                if self.model_params.mono and block.ndim > 1:
+                    block = np.mean(block, axis=1)
+ yield block
diff --git a/python/pyarmnn/examples/speech_recognition/audio_utils.py b/python/pyarmnn/examples/speech_recognition/audio_utils.py
new file mode 100644
index 0000000000..a522a0e2a7
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/audio_utils.py
@@ -0,0 +1,128 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Utilities for speech recognition apps."""
+
+import numpy as np
+import pyarmnn as ann
+
+
+def decode(model_output: np.ndarray, labels: dict) -> str:
+ """Decodes the integer encoded results from inference into a string.
+
+ Args:
+ model_output: Results from running inference.
+ labels: Dictionary of labels keyed on the classification index.
+
+ Returns:
+ Decoded string.
+ """
+ top1_results = [labels[np.argmax(row[0])] for row in model_output]
+ return filter_characters(top1_results)
+
+
+def filter_characters(results: list) -> str:
+ """Filters unwanted and duplicate characters.
+
+ Args:
+ results: List of top 1 results from inference.
+
+ Returns:
+ Final output string to present to user.
+ """
+ text = ""
+ for i in range(len(results)):
+ if results[i] == "$":
+ continue
+ elif i + 1 < len(results) and results[i] == results[i + 1]:
+ continue
+ else:
+ text += results[i]
+ return text
+
+
+def display_text(text: str):
+ """Presents the results on the console.
+
+ Args:
+ text: Results of performing ASR on the input audio data.
+ """
+ print(text, sep="", end="", flush=True)
+
+
+def quantize_input(data, input_binding_info):
+ """Quantize the float input to (u)int8 ready for inputting to model."""
+ if data.ndim != 2:
+ raise RuntimeError("Audio data must have 2 dimensions for quantization")
+
+ quant_scale = input_binding_info[1].GetQuantizationScale()
+ quant_offset = input_binding_info[1].GetQuantizationOffset()
+ data_type = input_binding_info[1].GetDataType()
+
+ if data_type == ann.DataType_QAsymmS8:
+ data_type = np.int8
+ elif data_type == ann.DataType_QAsymmU8:
+ data_type = np.uint8
+ else:
+ raise ValueError("Could not quantize data to required data type")
+
+ d_min = np.iinfo(data_type).min
+ d_max = np.iinfo(data_type).max
+
+ for row in range(data.shape[0]):
+ for col in range(data.shape[1]):
+ data[row, col] = (data[row, col] / quant_scale) + quant_offset
+ data[row, col] = np.clip(data[row, col], d_min, d_max)
+ data = data.astype(data_type)
+ return data
+
+
+def decode_text(is_first_window, labels, output_result):
+    """
+    Slices the output appropriately depending on the window position, and decodes it for wav2letter output.
+    * On the first run, take the left context and the inner context.
+    * On every other run, take only the inner context.
+    Also stores the current right context and updates it for each inference; it is used after the last inference.
+
+    Args:
+        is_first_window: Boolean indicating whether this is the first window we are running inference on
+        labels: The label set
+        output_result: The output from the inference
+    Returns:
+        current_r_context: The current right context
+        text: The text string decoded from the latest output
+    """
+
+ if is_first_window:
+ # Since it's the first inference, keep the left context, and inner context, and decode
+ text = decode(output_result[0][0:472], labels)
+ else:
+ # Only decode the inner context
+ text = decode(output_result[0][49:472], labels)
+
+ # Store the right context, we will need it after the last inference
+ current_r_context = decode(output_result[0][473:521], labels)
+ return current_r_context, text
+
+
+def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
+    """
+    Takes a block of audio data, extracts the MFCC features, quantizes the array if required, and uses ArmNN
+    to create the input tensors.
+
+    Args:
+        audio_data: The audio data to process
+        input_binding_info: The model input binding info
+        mfcc_preprocessor: The MFCC preprocessor instance
+    Returns:
+        input_tensors: The prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor
+    """
+
+ data_type = input_binding_info[1].GetDataType()
+ input_tensor = mfcc_preprocessor.extract_features(audio_data)
+ if data_type != ann.DataType_Float32:
+ input_tensor = quantize_input(input_tensor, input_binding_info)
+ input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+ return input_tensors
diff --git a/python/pyarmnn/examples/speech_recognition/preprocess.py b/python/pyarmnn/examples/speech_recognition/preprocess.py
new file mode 100644
index 0000000000..553ddba5de
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/preprocess.py
@@ -0,0 +1,260 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Class used to extract the Mel-frequency cepstral coefficients from a given audio frame."""
+
+import numpy as np
+
+
+class MFCCParams:
+ def __init__(self, sampling_freq, num_fbank_bins,
+ mel_lo_freq, mel_hi_freq, num_mfcc_feats, frame_len, use_htk_method, n_FFT):
+ self.sampling_freq = sampling_freq
+ self.num_fbank_bins = num_fbank_bins
+ self.mel_lo_freq = mel_lo_freq
+ self.mel_hi_freq = mel_hi_freq
+ self.num_mfcc_feats = num_mfcc_feats
+ self.frame_len = frame_len
+ self.use_htk_method = use_htk_method
+ self.n_FFT = n_FFT
+
+
+class MFCC:
+
+ def __init__(self, mfcc_params):
+ self.mfcc_params = mfcc_params
+ self.FREQ_STEP = 200.0 / 3
+ self.MIN_LOG_HZ = 1000.0
+ self.MIN_LOG_MEL = self.MIN_LOG_HZ / self.FREQ_STEP
+ self.LOG_STEP = 1.8562979903656 / 27.0
+ self.__frame_len_padded = int(2 ** (np.ceil((np.log(self.mfcc_params.frame_len) / np.log(2.0)))))
+ self.__filter_bank_initialised = False
+ self.__frame = np.zeros(self.__frame_len_padded)
+ self.__buffer = np.zeros(self.__frame_len_padded)
+ self.__filter_bank_filter_first = np.zeros(self.mfcc_params.num_fbank_bins)
+ self.__filter_bank_filter_last = np.zeros(self.mfcc_params.num_fbank_bins)
+ self.__mel_energies = np.zeros(self.mfcc_params.num_fbank_bins)
+ self.__dct_matrix = self.create_dct_matrix(self.mfcc_params.num_fbank_bins, self.mfcc_params.num_mfcc_feats)
+ self.__mel_filter_bank = self.create_mel_filter_bank()
+ self.__np_mel_bank = np.zeros([self.mfcc_params.num_fbank_bins, int(self.mfcc_params.n_FFT / 2) + 1])
+
+ for i in range(self.mfcc_params.num_fbank_bins):
+ k = 0
+ for j in range(int(self.__filter_bank_filter_first[i]), int(self.__filter_bank_filter_last[i]) + 1):
+ self.__np_mel_bank[i, j] = self.__mel_filter_bank[i][k]
+ k += 1
+
+ def mel_scale(self, freq, use_htk_method):
+        """
+        Converts a frequency in Hz to the mel scale.
+
+        Args:
+            freq: The frequency in Hz.
+            use_htk_method: Boolean to set whether to use the HTK method or not.
+
+        Returns:
+            The frequency on the mel scale.
+        """
+ if use_htk_method:
+ return 1127.0 * np.log(1.0 + freq / 700.0)
+ else:
+ mel = freq / self.FREQ_STEP
+
+ if freq >= self.MIN_LOG_HZ:
+ mel = self.MIN_LOG_MEL + np.log(freq / self.MIN_LOG_HZ) / self.LOG_STEP
+ return mel
+
+ def inv_mel_scale(self, mel_freq, use_htk_method):
+        """
+        Converts a mel-scale value back to a frequency in Hz.
+
+        Args:
+            mel_freq: The frequency on the mel scale.
+            use_htk_method: Boolean to set whether to use the HTK method or not.
+
+        Returns:
+            The frequency in Hz.
+        """
+ if use_htk_method:
+ return 700.0 * (np.exp(mel_freq / 1127.0) - 1.0)
+ else:
+ freq = self.FREQ_STEP * mel_freq
+
+ if mel_freq >= self.MIN_LOG_MEL:
+ freq = self.MIN_LOG_HZ * np.exp(self.LOG_STEP * (mel_freq - self.MIN_LOG_MEL))
+ return freq
+
+ def mfcc_compute(self, audio_data):
+ """
+ Extracts the MFCC for a single frame.
+
+ Args:
+ audio_data: The audio data to process.
+
+ Returns:
+ the MFCC features
+ """
+ if len(audio_data) != self.mfcc_params.frame_len:
+ raise ValueError(
+ f"audio_data buffer size {len(audio_data)} does not match the frame length {self.mfcc_params.frame_len}")
+
+ audio_data = np.array(audio_data)
+ spec = np.abs(np.fft.rfft(np.hanning(self.mfcc_params.n_FFT + 1)[0:self.mfcc_params.n_FFT] * audio_data,
+ self.mfcc_params.n_FFT)) ** 2
+ mel_energy = np.dot(self.__np_mel_bank.astype(np.float32),
+ np.transpose(spec).astype(np.float32))
+
+ mel_energy += 1e-10
+ log_mel_energy = 10.0 * np.log10(mel_energy)
+ top_db = 80.0
+
+ log_mel_energy = np.maximum(log_mel_energy, log_mel_energy.max() - top_db)
+
+ mfcc_feats = np.dot(self.__dct_matrix, log_mel_energy)
+
+ return mfcc_feats
+
+ def create_dct_matrix(self, num_fbank_bins, num_mfcc_feats):
+ """
+ Creates the Discrete Cosine Transform matrix to be used in the compute function.
+
+ Args:
+ num_fbank_bins: The number of filter bank bins
+ num_mfcc_feats: the number of MFCC features
+
+ Returns:
+ the DCT matrix
+ """
+ dct_m = np.zeros(num_fbank_bins * num_mfcc_feats)
+ for k in range(num_mfcc_feats):
+ for n in range(num_fbank_bins):
+ if k == 0:
+ dct_m[(k * num_fbank_bins) + n] = 2 * np.sqrt(1 / (4 * num_fbank_bins)) * np.cos(
+ (np.pi / num_fbank_bins) * (n + 0.5) * k)
+ else:
+ dct_m[(k * num_fbank_bins) + n] = 2 * np.sqrt(1 / (2 * num_fbank_bins)) * np.cos(
+ (np.pi / num_fbank_bins) * (n + 0.5) * k)
+
+ dct_m = np.reshape(dct_m, [self.mfcc_params.num_mfcc_feats, self.mfcc_params.num_fbank_bins])
+ return dct_m
+
+ def create_mel_filter_bank(self):
+ """
+ Creates the Mel filter bank.
+
+ Returns:
+ the mel filter bank
+ """
+ num_fft_bins = int(self.__frame_len_padded / 2)
+ fft_bin_width = self.mfcc_params.sampling_freq / self.__frame_len_padded
+
+ mel_low_freq = self.mel_scale(self.mfcc_params.mel_lo_freq, False)
+ mel_high_freq = self.mel_scale(self.mfcc_params.mel_hi_freq, False)
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (self.mfcc_params.num_fbank_bins + 1)
+
+ this_bin = np.zeros(num_fft_bins)
+ mel_fbank = [0] * self.mfcc_params.num_fbank_bins
+
+ for bin_num in range(self.mfcc_params.num_fbank_bins):
+ left_mel = mel_low_freq + bin_num * mel_freq_delta
+ center_mel = mel_low_freq + (bin_num + 1) * mel_freq_delta
+ right_mel = mel_low_freq + (bin_num + 2) * mel_freq_delta
+ first_index = last_index = -1
+
+ for i in range(num_fft_bins):
+ freq = (fft_bin_width * i)
+ mel = self.mel_scale(freq, False)
+ this_bin[i] = 0.0
+
+ if (mel > left_mel) and (mel < right_mel):
+ if mel <= center_mel:
+ weight = (mel - left_mel) / (center_mel - left_mel)
+ else:
+ weight = (right_mel - mel) / (right_mel - center_mel)
+
+ enorm = 2.0 / (self.inv_mel_scale(right_mel, False) - self.inv_mel_scale(left_mel, False))
+ weight *= enorm
+ this_bin[i] = weight
+
+ if first_index == -1:
+ first_index = i
+ last_index = i
+
+ self.__filter_bank_filter_first[bin_num] = first_index
+ self.__filter_bank_filter_last[bin_num] = last_index
+ mel_fbank[bin_num] = np.zeros(last_index - first_index + 1)
+ j = 0
+
+ for i in range(first_index, last_index + 1):
+ mel_fbank[bin_num][j] = this_bin[i]
+ j += 1
+
+ return mel_fbank
+
+
+class Preprocessor:
+
+ def __init__(self, mfcc, model_input_size, stride):
+ self.model_input_size = model_input_size
+ self.stride = stride
+
+ # Savitzky - Golay differential filters
+ self.__savgol_order1_coeffs = np.array([6.66666667e-02, 5.00000000e-02, 3.33333333e-02,
+ 1.66666667e-02, -3.46944695e-18, -1.66666667e-02,
+ -3.33333333e-02, -5.00000000e-02, -6.66666667e-02])
+
+ self.savgol_order2_coeffs = np.array([0.06060606, 0.01515152, -0.01731602,
+ -0.03679654, -0.04329004, -0.03679654,
+ -0.01731602, 0.01515152, 0.06060606])
+
+ self.__mfcc_calc = mfcc
+
+ def __normalize(self, values):
+ """
+ Normalize values to mean 0 and std 1
+ """
+ ret_val = (values - np.mean(values)) / np.std(values)
+ return ret_val
+
+ def __get_features(self, features, mfcc_instance, audio_data):
+ idx = 0
+ while len(features) < self.model_input_size * mfcc_instance.mfcc_params.num_mfcc_feats:
+ features.extend(mfcc_instance.mfcc_compute(audio_data[idx:idx + int(mfcc_instance.mfcc_params.frame_len)]))
+ idx += self.stride
+
+ def extract_features(self, audio_data):
+        """
+        Extracts the MFCC features and calculates each feature's first and second order derivatives.
+        The matrix returned is sized appropriately for input to the model, based
+        on the model info specified in the MFCC instance.
+
+        Args:
+            audio_data: The audio data to be used for this calculation
+        Returns:
+            The derived MFCC feature vector, sized appropriately for inference
+        """
+
+ num_samples_per_inference = ((self.model_input_size - 1)
+ * self.stride) + self.__mfcc_calc.mfcc_params.frame_len
+ if len(audio_data) < num_samples_per_inference:
+ raise ValueError("audio_data size for feature extraction is smaller than "
+ "the expected number of samples needed for inference")
+
+ features = []
+ self.__get_features(features, self.__mfcc_calc, np.asarray(audio_data))
+ features = np.reshape(np.array(features), (self.model_input_size, self.__mfcc_calc.mfcc_params.num_mfcc_feats))
+
+ mfcc_delta_np = np.zeros_like(features)
+ mfcc_delta2_np = np.zeros_like(features)
+
+        for i in range(features.shape[1]):
+            idelta = np.convolve(features[:, i], self.__savgol_order1_coeffs, 'same')
+            mfcc_delta_np[:, i] = idelta
+            ideltadelta = np.convolve(features[:, i], self.savgol_order2_coeffs, 'same')
+            mfcc_delta2_np[:, i] = ideltadelta
+
+ features = np.concatenate((self.__normalize(features), self.__normalize(mfcc_delta_np),
+ self.__normalize(mfcc_delta2_np)), axis=1)
+
+ return np.float32(features)
diff --git a/python/pyarmnn/examples/speech_recognition/requirements.txt b/python/pyarmnn/examples/speech_recognition/requirements.txt
new file mode 100644
index 0000000000..4b8f3e6d24
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/requirements.txt
@@ -0,0 +1,2 @@
+numpy>=1.19.2
+soundfile>=0.10.3 \ No newline at end of file
diff --git a/python/pyarmnn/examples/speech_recognition/run_audio_file.py b/python/pyarmnn/examples/speech_recognition/run_audio_file.py
new file mode 100644
index 0000000000..c7e4c6bc31
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/run_audio_file.py
@@ -0,0 +1,94 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Automatic speech recognition with PyArmNN demo for processing audio clips to text."""
+
+import sys
+import os
+from argparse import ArgumentParser
+
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
+from network_executor import ArmnnNetworkExecutor
+from utils import dict_labels
+from preprocess import MFCCParams, Preprocessor, MFCC
+from audio_capture import AudioCapture, ModelParams
+from audio_utils import decode_text, prepare_input_tensors, display_text
+
+
+def parse_args():
+ parser = ArgumentParser(description="ASR with PyArmNN")
+ parser.add_argument(
+ "--audio_file_path",
+ required=True,
+ type=str,
+ help="Path to the audio file to perform ASR",
+ )
+ parser.add_argument(
+ "--model_file_path",
+ required=True,
+ type=str,
+ help="Path to ASR model to use",
+ )
+ parser.add_argument(
+ "--labels_file_path",
+ required=True,
+ type=str,
+ help="Path to text file containing labels to map to model output",
+ )
+ parser.add_argument(
+ "--preferred_backends",
+ type=str,
+ nargs="+",
+ default=["CpuAcc", "CpuRef"],
+ help="""List of backends in order of preference for optimizing
+ subgraphs, falling back to the next backend in the list on unsupported
+ layers. Defaults to [CpuAcc, CpuRef]""",
+ )
+ return parser.parse_args()
+
+
+def main(args):
+ # Read command line args
+ audio_file = args.audio_file_path
+ model = ModelParams(args.model_file_path)
+ labels = dict_labels(args.labels_file_path)
+
+ # Create the ArmNN inference runner
+ network = ArmnnNetworkExecutor(model.path, args.preferred_backends)
+
+ audio_capture = AudioCapture(model)
+ buffer = audio_capture.from_audio_file(audio_file)
+
+ # Create the preprocessor
+ mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=128, mel_lo_freq=0, mel_hi_freq=8000,
+ num_mfcc_feats=13, frame_len=512, use_htk_method=False, n_FFT=512)
+ mfcc = MFCC(mfcc_params)
+ preprocessor = Preprocessor(mfcc, model_input_size=1044, stride=160)
+
+ text = ""
+ current_r_context = ""
+ is_first_window = True
+
+ print("Processing Audio Frames...")
+ for audio_data in buffer:
+ # Prepare the input Tensors
+ input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+
+ # Run inference
+ output_result = network.run(input_tensors)
+
+ # Slice and Decode the text, and store the right context
+ current_r_context, text = decode_text(is_first_window, labels, output_result)
+
+ is_first_window = False
+
+ display_text(text)
+
+ print(current_r_context, flush=True)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/python/pyarmnn/examples/speech_recognition/tests/conftest.py b/python/pyarmnn/examples/speech_recognition/tests/conftest.py
new file mode 100644
index 0000000000..730c291cfa
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/conftest.py
@@ -0,0 +1,34 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import ntpath
+
+import urllib.request
+
+import pytest
+
+script_dir = os.path.dirname(__file__)
+
+@pytest.fixture(scope="session")
+def test_data_folder(request):
+ """
+ This fixture returns path to folder with shared test resources among all tests
+ """
+
+ data_dir = os.path.join(script_dir, "testdata")
+
+ if not os.path.exists(data_dir):
+ os.mkdir(data_dir)
+
+ files_to_download = ["https://raw.githubusercontent.com/Azure-Samples/cognitive-services-speech-sdk/master"
+ "/sampledata/audiofiles/myVoiceIsMyPassportVerifyMe04.wav"]
+
+ for file in files_to_download:
+ path, filename = ntpath.split(file)
+ file_path = os.path.join(script_dir, "testdata", filename)
+ if not os.path.exists(file_path):
+ print("\nDownloading test file: " + file_path + "\n")
+ urllib.request.urlretrieve(file, file_path)
+
+ return data_dir
diff --git a/python/pyarmnn/examples/speech_recognition/tests/context.py b/python/pyarmnn/examples/speech_recognition/tests/context.py
new file mode 100644
index 0000000000..a810010e9f
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/context.py
@@ -0,0 +1,13 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'common'))
+import utils as common_utils
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+import audio_capture
+import audio_utils
+import preprocess
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py b/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py
new file mode 100644
index 0000000000..281d0df587
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py
@@ -0,0 +1,17 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+
+import numpy as np
+
+from context import audio_capture
+
+
+def test_audio_file(test_data_folder):
+ audio_file = os.path.join(test_data_folder, "myVoiceIsMyPassportVerifyMe04.wav")
+ capture = audio_capture.AudioCapture(audio_capture.ModelParams(""))
+ buffer = capture.from_audio_file(audio_file)
+ audio_data = next(buffer)
+ assert audio_data.shape == (47712,)
+ assert audio_data.dtype == np.float32
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py b/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py
new file mode 100644
index 0000000000..3b99e6504a
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py
@@ -0,0 +1,28 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+
+import numpy as np
+
+from context import common_utils
+from context import audio_utils
+
+
+def test_labels(test_data_folder):
+ labels_file = os.path.join(test_data_folder, "wav2letter_labels.txt")
+ labels = common_utils.dict_labels(labels_file)
+ assert len(labels) == 29
+ assert labels[26] == "\'"
+ assert labels[27] == r" "
+ assert labels[28] == "$"
+
+
+def test_decoder(test_data_folder):
+ labels_file = os.path.join(test_data_folder, "wav2letter_labels.txt")
+ labels = common_utils.dict_labels(labels_file)
+
+ output_tensor = os.path.join(test_data_folder, "inf_out.npy")
+ encoded = np.load(output_tensor)
+ decoded_text = audio_utils.decode(encoded, labels)
+ assert decoded_text == "and he walkd immediately out of the apartiment by anothe"
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py b/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py
new file mode 100644
index 0000000000..d692ab51c8
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py
@@ -0,0 +1,286 @@
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import numpy as np
+
+from context import preprocess
+
+test_wav = [
+ -3,0,1,-1,2,3,-2,2,
+ 1,-2,0,3,-1,8,3,2,
+ -1,-1,2,7,3,5,6,6,
+ 6,12,5,6,3,3,5,4,
+ 4,6,7,7,7,3,7,2,
+ 8,4,4,2,-4,-1,-1,-4,
+ 2,1,-1,-4,0,-7,-6,-2,
+ -5,1,-5,-1,-7,-3,-3,-7,
+ 0,-3,3,-5,0,1,-2,-2,
+ -3,-3,-7,-3,-2,-6,-5,-8,
+ -2,-8,4,-9,-4,-9,-5,-5,
+ -3,-9,-3,-9,-1,-7,-4,1,
+ -3,2,-8,-4,-4,-5,1,-3,
+ -1,0,-1,-2,-3,-2,-4,-1,
+ 1,-1,3,0,3,2,0,0,
+ 0,-3,1,1,0,8,3,4,
+ 1,5,6,4,7,3,3,0,
+ 3,6,7,6,4,5,9,9,
+ 5,5,8,1,6,9,6,6,
+ 7,1,8,1,5,0,5,5,
+ 0,3,2,7,2,-3,3,0,
+ 3,0,0,0,2,0,-1,-1,
+ -2,-3,-8,0,1,0,-3,-3,
+ -3,-2,-3,-3,-4,-6,-2,-8,
+ -9,-4,-1,-5,-3,-3,-4,-3,
+ -6,3,0,-1,-2,-9,-4,-2,
+ 2,-1,3,-5,-5,-2,0,-2,
+ 0,-1,-3,1,-2,9,4,5,
+ 2,2,1,0,-6,-2,0,0,
+ 0,-1,4,-4,3,-7,-1,5,
+ -6,-1,-5,4,3,9,-2,1,
+ 3,0,0,-2,1,2,1,1,
+ 0,3,2,-1,3,-3,7,0,
+ 0,3,2,2,-2,3,-2,2,
+ -3,4,-1,-1,-5,-1,-3,-2,
+ 1,-1,3,2,4,1,2,-2,
+ 0,2,7,0,8,-3,6,-3,
+ 6,1,2,-3,-1,-1,-1,1,
+ -2,2,1,2,0,-2,3,-2,
+ 3,-2,1,0,-3,-1,-2,-4,
+ -6,-5,-8,-1,-4,0,-3,-1,
+ -1,-1,0,-2,-3,-7,-1,0,
+ 1,5,0,5,1,1,-3,0,
+ -6,3,-8,4,-8,6,-6,1,
+ -6,-2,-5,-6,0,-5,4,-1,
+ 4,-2,1,2,1,0,-2,0,
+ 0,2,-2,2,-5,2,0,-2,
+ 1,-2,0,5,1,0,1,5,
+ 0,8,3,2,2,0,5,-2,
+ 3,1,0,1,0,-2,-1,-3,
+ 1,-1,3,0,3,0,-2,-1,
+ -4,-4,-4,-1,-4,-4,-3,-6,
+ -3,-7,-3,-1,-2,0,-5,-4,
+ -7,-3,-2,-2,1,2,2,8,
+ 5,4,2,4,3,5,0,3,
+ 3,6,4,2,2,-2,4,-2,
+ 3,3,2,1,1,4,-5,2,
+ -3,0,-1,1,-2,2,5,1,
+ 4,2,3,1,-1,1,0,6,
+ 0,-2,-1,1,-1,2,-5,-1,
+ -5,-1,-6,-3,-3,2,4,0,
+ -1,-5,3,-4,-1,-3,-4,1,
+ -4,1,-1,-1,0,-5,-4,-2,
+ -1,-1,-3,-7,-3,-3,4,4,
+]
+
+def test_mel_scale_function_with_htk_true():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+ mil_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel = mfcc_inst.mel_scale(16, True)
+
+ assert np.isclose(mel, 25.470010570730597)
+
+
+def test_mel_scale_function_with_htk_false():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+ mil_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel = mfcc_inst.mel_scale(16, False)
+
+ assert np.isclose(mel, 0.24)
+
+
+def test_inverse_mel_scale_function_with_htk_true():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+ mil_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel = mfcc_inst.inv_mel_scale(16, True)
+
+ assert np.isclose(mel, 10.008767240008943)
+
+
+def test_inverse_mel_scale_function_with_htk_false():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+ mil_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel = mfcc_inst.inv_mel_scale(16, False)
+
+ assert np.isclose(mel, 1071.170287494467)
+
+
+def test_create_mel_filter_bank():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+ mil_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+
+ mel_filter_bank = mfcc_inst.create_mel_filter_bank()
+
+ assert len(mel_filter_bank) == 128
+
+ assert str(mel_filter_bank[0]) == "[0.02837754]"
+ assert str(mel_filter_bank[1]) == "[0.01438901 0.01398853]"
+ assert str(mel_filter_bank[2]) == "[0.02877802]"
+ assert str(mel_filter_bank[3]) == "[0.04236608]"
+ assert str(mel_filter_bank[4]) == "[0.00040047 0.02797707]"
+ assert str(mel_filter_bank[5]) == "[0.01478948 0.01358806]"
+ assert str(mel_filter_bank[50]) == "[0.03298853]"
+ assert str(mel_filter_bank[100]) == "[0.00260166 0.00588759 0.00914814 0.00798015 0.00476919 0.00158245]"
+
+
+def test_mfcc_compute():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+ mil_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ audio_data = np.array(test_wav) / (2 ** 15)
+
+ mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+ mfcc_feats = mfcc_inst.mfcc_compute(audio_data)
+
+ assert np.isclose((mfcc_feats[0]), -834.9656973095651)
+ assert np.isclose((mfcc_feats[1]), 21.026915475076322)
+ assert np.isclose((mfcc_feats[2]), 18.628541708201688)
+ assert np.isclose((mfcc_feats[3]), 7.341153529494758)
+ assert np.isclose((mfcc_feats[4]), 18.907974386153214)
+ assert np.isclose((mfcc_feats[5]), -5.360387487466194)
+ assert np.isclose((mfcc_feats[6]), 6.523572638527085)
+ assert np.isclose((mfcc_feats[7]), -11.270643644983316)
+ assert np.isclose((mfcc_feats[8]), 8.375177203773777)
+ assert np.isclose((mfcc_feats[9]), 12.06721844362991)
+ assert np.isclose((mfcc_feats[10]), 8.30815892468875)
+ assert np.isclose((mfcc_feats[11]), -13.499911910889917)
+ assert np.isclose((mfcc_feats[12]), -18.176121251436165)
+
+
+def test_sliding_window_for_small_num_samples():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ mode_input_size = 9
+ stride = 160
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+ mil_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ audio_data = np.array(test_wav) / (2 ** 15)
+
+ full_audio_data = np.tile(audio_data, 9)
+
+ mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+ preprocessor = preprocess.Preprocessor(mfcc_inst, mode_input_size, stride)
+
+ input_tensor = preprocessor.extract_features(full_audio_data)
+
+ assert np.isclose(input_tensor[0][0], -3.4660944830426454)
+ assert np.isclose(input_tensor[0][1], 0.3587718932127629)
+ assert np.isclose(input_tensor[0][2], 0.3480551325669172)
+ assert np.isclose(input_tensor[0][3], 0.2976191917228921)
+ assert np.isclose(input_tensor[0][4], 0.3493037340849936)
+ assert np.isclose(input_tensor[0][5], 0.2408643285767937)
+ assert np.isclose(input_tensor[0][6], 0.2939659585037282)
+ assert np.isclose(input_tensor[0][7], 0.2144552669573928)
+ assert np.isclose(input_tensor[0][8], 0.302239565899944)
+ assert np.isclose(input_tensor[0][9], 0.3187368787077345)
+ assert np.isclose(input_tensor[0][10], 0.3019401051295793)
+ assert np.isclose(input_tensor[0][11], 0.20449412797602678)
+
+ assert np.isclose(input_tensor[0][38], -0.18751440767749533)
+
+
+def test_sliding_window_for_wav_2_letter_sized_input():
+ samp_freq = 16000
+ frame_len_ms = 32
+ frame_len_samples = samp_freq * frame_len_ms * 0.001
+ num_mfcc_feats = 13
+ mode_input_size = 296
+ stride = 160
+ num_fbank_bins = 128
+ mel_lo_freq = 0
+ mil_hi_freq = 8000
+ use_htk = False
+ n_FFT = 512
+
+ audio_data = np.zeros(47712, dtype=int)
+
+ mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
+ frame_len_samples, use_htk, n_FFT)
+
+ mfcc_inst = preprocess.MFCC(mfcc_params)
+ preprocessor = preprocess.Preprocessor(mfcc_inst, mode_input_size, stride)
+
+ input_tensor = preprocessor.extract_features(audio_data)
+
+ assert len(input_tensor[0]) == 39
+ assert len(input_tensor) == 296
diff --git a/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy b/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy
new file mode 100644
index 0000000000..a6f9ec0c70
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy
Binary files differ
diff --git a/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt b/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt
new file mode 100644
index 0000000000..d7485b7da2
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt
@@ -0,0 +1,29 @@
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+'
+
+$ \ No newline at end of file