author     alexander <alexander.efremov@arm.com>  2021-07-16 11:30:56 +0100
committer  Jim Flynn <jim.flynn@arm.com>          2022-02-04 09:55:21 +0000
commit     f42f56870c6201a876f025a423eb5540d7438e83 (patch)
tree       e8e57e371c851cbb9a51a2f3ec35059addd2e93e
parent     9d74ba6e85a043e9603445e062315f5c4965fbd6 (diff)
download   armnn-f42f56870c6201a876f025a423eb5540d7438e83.tar.gz
MLECO-2079 Adding the python KWS example
Signed-off-by: Eanna O Cathain <eanna.ocathain@arm.com>
Change-Id: Ie1463aaeb5e3cade22df8f560ae99a8e1c4a9c17
-rw-r--r--  python/pyarmnn/examples/common/audio_capture.py  149
-rw-r--r--  python/pyarmnn/examples/common/cv_utils.py  8
-rw-r--r--  python/pyarmnn/examples/common/mfcc.py (renamed from python/pyarmnn/examples/speech_recognition/preprocess.py)  152
-rw-r--r--  python/pyarmnn/examples/common/tests/context.py  7
-rw-r--r--  python/pyarmnn/examples/common/utils.py  69
-rw-r--r--  python/pyarmnn/examples/keyword_spotting/README.MD  189
-rw-r--r--  python/pyarmnn/examples/keyword_spotting/__init__.py  0
-rw-r--r--  python/pyarmnn/examples/keyword_spotting/audio_utils.py  31
-rw-r--r--  python/pyarmnn/examples/keyword_spotting/requirements.txt  5
-rw-r--r--  python/pyarmnn/examples/keyword_spotting/run_audio_classification.py  136
-rw-r--r--  python/pyarmnn/examples/object_detection/run_video_file.py  170
-rw-r--r--  python/pyarmnn/examples/object_detection/run_video_stream.py  175
-rw-r--r--  python/pyarmnn/examples/speech_recognition/README.md  39
-rw-r--r--  python/pyarmnn/examples/speech_recognition/audio_capture.py  56
-rw-r--r--  python/pyarmnn/examples/speech_recognition/audio_utils.py  53
-rw-r--r--  python/pyarmnn/examples/speech_recognition/requirements.txt  5
-rw-r--r--  python/pyarmnn/examples/speech_recognition/run_audio_file.py  48
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/conftest.py  58
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/context.py  13
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py  17
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_decoder.py  15
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py  286
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy  bin 4420 -> 0 bytes
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/testdata/inference_output.npy  bin 0 -> 2999 bytes
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/testdata/quick_brown_fox_16000khz.wav  bin 196728 -> 0 bytes
-rw-r--r--  python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt  29
-rw-r--r--  python/pyarmnn/examples/speech_recognition/wav2letter_mfcc.py  91
-rw-r--r--  python/pyarmnn/examples/tests/conftest.py (renamed from python/pyarmnn/examples/common/tests/conftest.py)  15
-rw-r--r--  python/pyarmnn/examples/tests/context.py  22
-rw-r--r--  python/pyarmnn/examples/tests/test_common_utils.py (renamed from python/pyarmnn/examples/common/tests/test_utils.py)  0
-rw-r--r--  python/pyarmnn/examples/tests/test_mfcc.py  247
-rw-r--r--  python/pyarmnn/examples/tests/test_network_executor.py (renamed from python/pyarmnn/examples/common/tests/test_network_executor.py)  4
-rw-r--r--  python/pyarmnn/examples/tests/testdata/labelmap.txt  9
33 files changed, 1289 insertions, 809 deletions
diff --git a/python/pyarmnn/examples/common/audio_capture.py b/python/pyarmnn/examples/common/audio_capture.py
new file mode 100644
index 0000000000..1bd53b4473
--- /dev/null
+++ b/python/pyarmnn/examples/common/audio_capture.py
@@ -0,0 +1,149 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+"""Contains CaptureAudioStream class for capturing chunks of audio data from incoming
+ stream and generic capture_audio function for capturing from files."""
+import collections
+import time
+from queue import Queue
+from typing import Generator
+
+import numpy as np
+import sounddevice as sd
+import soundfile as sf
+
+AudioCaptureParams = collections.namedtuple('AudioCaptureParams',
+ ['dtype', 'overlap', 'min_samples', 'sampling_freq', 'mono'])
+
+
+def capture_audio(audio_file_path, params_tuple) -> Generator[np.ndarray, None, None]:
+    """Creates a generator that yields audio data from a file. Data is padded with
+    zeros if necessary to make up the minimum number of samples.
+    Args:
+        audio_file_path: Path to the audio file provided by the user.
+        params_tuple: Sampling parameters for the model used.
+ Yields:
+ Blocks of audio data of minimum sample size.
+ """
+ with sf.SoundFile(audio_file_path) as audio_file:
+ for block in audio_file.blocks(
+ blocksize=params_tuple.min_samples,
+ dtype=params_tuple.dtype,
+ always_2d=True,
+ fill_value=0,
+ overlap=params_tuple.overlap
+ ):
+ if params_tuple.mono and block.shape[0] > 1:
+ block = np.mean(block, dtype=block.dtype, axis=1)
+ yield block
+
+
+class CaptureAudioStream:
+
+ def __init__(self, audio_capture_params):
+ self.audio_capture_params = audio_capture_params
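+        # Rolling buffer: `overlap` samples carried over from the previous window, followed by the newest block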
+ self.collection = np.zeros(self.audio_capture_params.min_samples + self.audio_capture_params.overlap).astype(
+ dtype=self.audio_capture_params.dtype)
+ self.is_active = True
+ self.is_first_window = True
+ self.duration = False
+ self.block_count = 0
+ self.current_block = 0
+ self.queue = Queue(2)
+
+ def set_stream_defaults(self):
+ """Discovers input devices on the system and sets default stream parameters."""
+ print(sd.query_devices())
+ device = input("Select input device by index or name: ")
+
+ try:
+ sd.default.device = int(device)
+ except ValueError:
+ sd.default.device = str(device)
+
+ sd.default.samplerate = self.audio_capture_params.sampling_freq
+ sd.default.blocksize = self.audio_capture_params.min_samples
+ sd.default.dtype = self.audio_capture_params.dtype
+ sd.default.channels = 1 if self.audio_capture_params.mono else 2
+
+ def set_recording_duration(self, duration):
+ """Sets a time duration (in integer seconds) for recording audio. Total time duration is
+ adjusted to a minimum based on the parameters of the model used. Durations less than 1
+ result in endless recording.
+
+ Args:
+ duration (int): User-provided command line argument for time duration of recording.
+ """
+ if duration > 0:
+ min_duration = int(
+ np.ceil(self.audio_capture_params.min_samples / self.audio_capture_params.sampling_freq)
+ )
+ if duration < min_duration:
+ print(f"Minimum duration must be {min_duration} seconds of audio")
+ print(f"Setting minimum recording duration...")
+ duration = min_duration
+
+ print(f"Recording duration is {duration} seconds")
+ self.duration = self.audio_capture_params.sampling_freq * duration
+ self.block_count, remainder_samples = divmod(
+ self.duration, self.audio_capture_params.min_samples
+ )
+
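+            # Round the block count up if the leftover samples amount to more than half a second of audio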
+ if remainder_samples > 0.5 * self.audio_capture_params.sampling_freq:
+ self.block_count += 1
+ else:
+ self.duration = False # Record forever
+
+ def countdown(self, delay=3):
+        """Counts down for `delay` seconds (default 3) before recording audio begins."""
+ print("Beginning recording in...")
+ for i in range(delay, 0, -1):
+ print(f"{i}...")
+ time.sleep(1)
+
+ def update(self):
+ """If a duration has been set, increments a counter to update the number of blocks of audio
+ data left to be collected. The stream is deactivated upon reaching the maximum block count
+ determined by the duration.
+ """
+ if self.duration:
+ self.current_block += 1
+ if self.current_block == self.block_count:
+ self.is_active = False
+
+ def capture_data(self):
+ """Gets the next window of audio data by retrieving the newest data from a queue and
+ shifting the position of the data in the collection. Overlap values of less than `min_samples` are supported.
+ """
+ new_data = self.queue.get()
+
+ if self.is_first_window or self.audio_capture_params.overlap == 0:
+ self.collection[:self.audio_capture_params.min_samples] = new_data[:]
+
+ elif self.audio_capture_params.overlap < self.audio_capture_params.min_samples:
+            # Move the last `overlap` samples of the previous window to the start of the buffer
+ self.collection[0:self.audio_capture_params.overlap] = \
+ self.collection[(self.audio_capture_params.min_samples - self.audio_capture_params.overlap):
+ self.audio_capture_params.min_samples]
+
+ self.collection[self.audio_capture_params.overlap:(
+ self.audio_capture_params.overlap + self.audio_capture_params.min_samples)] = new_data[:]
+ else:
+ raise ValueError(
+ "Capture Error: Overlap must be less than {}".format(self.audio_capture_params.min_samples))
+ audio_data = self.collection[0:self.audio_capture_params.min_samples]
+ return np.asarray(audio_data).astype(self.audio_capture_params.dtype)
+
+ def callback(self, data, frames, time, status):
+        """Places audio data from the active stream into a queue for processing and
+        updates the block counter if the recording duration is finite.
+ """
+
+ if self.duration:
+ self.update()
+
+ if self.audio_capture_params.mono:
+ audio_data = data.copy().flatten()
+ else:
+ audio_data = data.copy()
+
+ self.queue.put(audio_data)
diff --git a/python/pyarmnn/examples/common/cv_utils.py b/python/pyarmnn/examples/common/cv_utils.py
index fd848b8b0f..e12ff50548 100644
--- a/python/pyarmnn/examples/common/cv_utils.py
+++ b/python/pyarmnn/examples/common/cv_utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""
@@ -14,7 +14,7 @@ import numpy as np
import pyarmnn as ann
-def preprocess(frame: np.ndarray, input_binding_info: tuple):
+def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool):
"""
Takes a frame, resizes, swaps channels and converts data type to match
model input layer. The converted frame is wrapped in a const tensor
@@ -23,6 +23,7 @@ def preprocess(frame: np.ndarray, input_binding_info: tuple):
Args:
frame: Captured frame from video.
input_binding_info: Contains shape and data type of model input layer.
+        is_normalised: True if the input layer expects data normalised to the range [0, 1].
Returns:
Input tensor.
@@ -34,7 +35,8 @@ def preprocess(frame: np.ndarray, input_binding_info: tuple):
# Expand dimensions and convert data type to match model input
if input_binding_info[1].GetDataType() == ann.DataType_Float32:
data_type = np.float32
- resized_frame = resized_frame.astype("float32")/255
+ if is_normalised:
+ resized_frame = resized_frame.astype("float32")/255
else:
data_type = np.uint8
diff --git a/python/pyarmnn/examples/speech_recognition/preprocess.py b/python/pyarmnn/examples/common/mfcc.py
index 553ddba5de..2bab669fb7 100644
--- a/python/pyarmnn/examples/speech_recognition/preprocess.py
+++ b/python/pyarmnn/examples/common/mfcc.py
@@ -1,22 +1,13 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Class used to extract the Mel-frequency cepstral coefficients from a given audio frame."""
import numpy as np
+import collections
-
-class MFCCParams:
- def __init__(self, sampling_freq, num_fbank_bins,
- mel_lo_freq, mel_hi_freq, num_mfcc_feats, frame_len, use_htk_method, n_FFT):
- self.sampling_freq = sampling_freq
- self.num_fbank_bins = num_fbank_bins
- self.mel_lo_freq = mel_lo_freq
- self.mel_hi_freq = mel_hi_freq
- self.num_mfcc_feats = num_mfcc_feats
- self.frame_len = frame_len
- self.use_htk_method = use_htk_method
- self.n_FFT = n_FFT
+MFCCParams = collections.namedtuple('MFCCParams', ['sampling_freq', 'num_fbank_bins', 'mel_lo_freq', 'mel_hi_freq',
+ 'num_mfcc_feats', 'frame_len', 'use_htk_method', 'n_fft'])
class MFCC:
@@ -27,21 +18,21 @@ class MFCC:
self.MIN_LOG_HZ = 1000.0
self.MIN_LOG_MEL = self.MIN_LOG_HZ / self.FREQ_STEP
self.LOG_STEP = 1.8562979903656 / 27.0
- self.__frame_len_padded = int(2 ** (np.ceil((np.log(self.mfcc_params.frame_len) / np.log(2.0)))))
- self.__filter_bank_initialised = False
- self.__frame = np.zeros(self.__frame_len_padded)
- self.__buffer = np.zeros(self.__frame_len_padded)
- self.__filter_bank_filter_first = np.zeros(self.mfcc_params.num_fbank_bins)
- self.__filter_bank_filter_last = np.zeros(self.mfcc_params.num_fbank_bins)
+ self._frame_len_padded = int(2 ** (np.ceil((np.log(self.mfcc_params.frame_len) / np.log(2.0)))))
+ self._filter_bank_initialised = False
+ self.__frame = np.zeros(self._frame_len_padded)
+ self.__buffer = np.zeros(self._frame_len_padded)
+ self._filter_bank_filter_first = np.zeros(self.mfcc_params.num_fbank_bins)
+ self._filter_bank_filter_last = np.zeros(self.mfcc_params.num_fbank_bins)
self.__mel_energies = np.zeros(self.mfcc_params.num_fbank_bins)
- self.__dct_matrix = self.create_dct_matrix(self.mfcc_params.num_fbank_bins, self.mfcc_params.num_mfcc_feats)
+ self._dct_matrix = self.create_dct_matrix(self.mfcc_params.num_fbank_bins, self.mfcc_params.num_mfcc_feats)
self.__mel_filter_bank = self.create_mel_filter_bank()
- self.__np_mel_bank = np.zeros([self.mfcc_params.num_fbank_bins, int(self.mfcc_params.n_FFT / 2) + 1])
+ self._np_mel_bank = np.zeros([self.mfcc_params.num_fbank_bins, int(self.mfcc_params.n_fft / 2) + 1])
for i in range(self.mfcc_params.num_fbank_bins):
k = 0
- for j in range(int(self.__filter_bank_filter_first[i]), int(self.__filter_bank_filter_last[i]) + 1):
- self.__np_mel_bank[i, j] = self.__mel_filter_bank[i][k]
+ for j in range(int(self._filter_bank_filter_first[i]), int(self._filter_bank_filter_last[i]) + 1):
+ self._np_mel_bank[i, j] = self.__mel_filter_bank[i][k]
k += 1
def mel_scale(self, freq, use_htk_method):
@@ -84,6 +75,14 @@ class MFCC:
freq = self.MIN_LOG_HZ * np.exp(self.LOG_STEP * (mel_freq - self.MIN_LOG_MEL))
return freq
+ def spectrum_calc(self, audio_data):
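+        # Magnitude spectrum of the frame after applying a Hann window, evaluated with an n_fft-point FFT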
+ return np.abs(np.fft.rfft(np.hanning(self.mfcc_params.frame_len + 1)[0:self.mfcc_params.frame_len] * audio_data,
+ self.mfcc_params.n_fft))
+
+ def log_mel(self, mel_energy):
+        mel_energy += 1e-10  # Avoid taking the log of zero
+ return np.log(mel_energy)
+
def mfcc_compute(self, audio_data):
"""
Extracts the MFCC for a single frame.
@@ -96,22 +95,14 @@ class MFCC:
"""
if len(audio_data) != self.mfcc_params.frame_len:
raise ValueError(
- f"audio_data buffer size {len(audio_data)} does not match the frame length {self.mfcc_params.frame_len}")
+ f"audio_data buffer size {len(audio_data)} does not match frame length {self.mfcc_params.frame_len}")
audio_data = np.array(audio_data)
- spec = np.abs(np.fft.rfft(np.hanning(self.mfcc_params.n_FFT + 1)[0:self.mfcc_params.n_FFT] * audio_data,
- self.mfcc_params.n_FFT)) ** 2
- mel_energy = np.dot(self.__np_mel_bank.astype(np.float32),
+ spec = self.spectrum_calc(audio_data)
+ mel_energy = np.dot(self._np_mel_bank.astype(np.float32),
np.transpose(spec).astype(np.float32))
-
- mel_energy += 1e-10
- log_mel_energy = 10.0 * np.log10(mel_energy)
- top_db = 80.0
-
- log_mel_energy = np.maximum(log_mel_energy, log_mel_energy.max() - top_db)
-
- mfcc_feats = np.dot(self.__dct_matrix, log_mel_energy)
-
+ log_mel_energy = self.log_mel(mel_energy)
+ mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
return mfcc_feats
def create_dct_matrix(self, num_fbank_bins, num_mfcc_feats):
@@ -125,19 +116,21 @@ class MFCC:
Returns:
the DCT matrix
"""
+
dct_m = np.zeros(num_fbank_bins * num_mfcc_feats)
for k in range(num_mfcc_feats):
for n in range(num_fbank_bins):
- if k == 0:
- dct_m[(k * num_fbank_bins) + n] = 2 * np.sqrt(1 / (4 * num_fbank_bins)) * np.cos(
- (np.pi / num_fbank_bins) * (n + 0.5) * k)
- else:
- dct_m[(k * num_fbank_bins) + n] = 2 * np.sqrt(1 / (2 * num_fbank_bins)) * np.cos(
- (np.pi / num_fbank_bins) * (n + 0.5) * k)
-
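+                # DCT-II basis: sqrt(2 / num_fbank_bins) * cos(pi * k * (n + 0.5) / num_fbank_bins)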
+ dct_m[(k * num_fbank_bins) + n] = (np.sqrt(2 / num_fbank_bins)) * np.cos(
+ (np.pi / num_fbank_bins) * (n + 0.5) * k)
dct_m = np.reshape(dct_m, [self.mfcc_params.num_mfcc_feats, self.mfcc_params.num_fbank_bins])
return dct_m
+ def mel_norm(self, weight, right_mel, left_mel):
+ """
+        Placeholder function overridden in child classes.
+ """
+ return weight
+
def create_mel_filter_bank(self):
"""
Creates the Mel filter bank.
@@ -145,16 +138,17 @@ class MFCC:
Returns:
the mel filter bank
"""
- num_fft_bins = int(self.__frame_len_padded / 2)
- fft_bin_width = self.mfcc_params.sampling_freq / self.__frame_len_padded
+ # FFT calculations are greatly accelerated for frame lengths which are powers of 2
+ # Frames are padded and FFT bin width/length calculated accordingly
+ num_fft_bins = int(self._frame_len_padded / 2)
+ fft_bin_width = self.mfcc_params.sampling_freq / self._frame_len_padded
- mel_low_freq = self.mel_scale(self.mfcc_params.mel_lo_freq, False)
- mel_high_freq = self.mel_scale(self.mfcc_params.mel_hi_freq, False)
+ mel_low_freq = self.mel_scale(self.mfcc_params.mel_lo_freq, self.mfcc_params.use_htk_method)
+ mel_high_freq = self.mel_scale(self.mfcc_params.mel_hi_freq, self.mfcc_params.use_htk_method)
mel_freq_delta = (mel_high_freq - mel_low_freq) / (self.mfcc_params.num_fbank_bins + 1)
this_bin = np.zeros(num_fft_bins)
mel_fbank = [0] * self.mfcc_params.num_fbank_bins
-
for bin_num in range(self.mfcc_params.num_fbank_bins):
left_mel = mel_low_freq + bin_num * mel_freq_delta
center_mel = mel_low_freq + (bin_num + 1) * mel_freq_delta
@@ -163,7 +157,7 @@ class MFCC:
for i in range(num_fft_bins):
freq = (fft_bin_width * i)
- mel = self.mel_scale(freq, False)
+ mel = self.mel_scale(freq, self.mfcc_params.use_htk_method)
this_bin[i] = 0.0
if (mel > left_mel) and (mel < right_mel):
@@ -172,16 +166,14 @@ class MFCC:
else:
weight = (right_mel - mel) / (right_mel - center_mel)
- enorm = 2.0 / (self.inv_mel_scale(right_mel, False) - self.inv_mel_scale(left_mel, False))
- weight *= enorm
- this_bin[i] = weight
+ this_bin[i] = self.mel_norm(weight, right_mel, left_mel)
if first_index == -1:
first_index = i
last_index = i
- self.__filter_bank_filter_first[bin_num] = first_index
- self.__filter_bank_filter_last[bin_num] = last_index
+ self._filter_bank_filter_first[bin_num] = first_index
+ self._filter_bank_filter_last[bin_num] = last_index
mel_fbank[bin_num] = np.zeros(last_index - first_index + 1)
j = 0
@@ -192,69 +184,55 @@ class MFCC:
return mel_fbank
-class Preprocessor:
+class AudioPreprocessor:
def __init__(self, mfcc, model_input_size, stride):
self.model_input_size = model_input_size
self.stride = stride
+ self._mfcc_calc = mfcc
- # Savitzky - Golay differential filters
- self.__savgol_order1_coeffs = np.array([6.66666667e-02, 5.00000000e-02, 3.33333333e-02,
- 1.66666667e-02, -3.46944695e-18, -1.66666667e-02,
- -3.33333333e-02, -5.00000000e-02, -6.66666667e-02])
-
- self.savgol_order2_coeffs = np.array([0.06060606, 0.01515152, -0.01731602,
- -0.03679654, -0.04329004, -0.03679654,
- -0.01731602, 0.01515152, 0.06060606])
-
- self.__mfcc_calc = mfcc
-
- def __normalize(self, values):
+ def _normalize(self, values):
"""
Normalize values to mean 0 and std 1
"""
ret_val = (values - np.mean(values)) / np.std(values)
return ret_val
- def __get_features(self, features, mfcc_instance, audio_data):
+ def _get_features(self, features, mfcc_instance, audio_data):
idx = 0
while len(features) < self.model_input_size * mfcc_instance.mfcc_params.num_mfcc_feats:
- features.extend(mfcc_instance.mfcc_compute(audio_data[idx:idx + int(mfcc_instance.mfcc_params.frame_len)]))
+ current_frame_feats = mfcc_instance.mfcc_compute(audio_data[idx:idx + int(mfcc_instance.mfcc_params.frame_len)])
+ features.extend(current_frame_feats)
idx += self.stride
+ def mfcc_delta_calc(self, features):
+ """
+        Placeholder function overridden in child classes.
+ """
+ return features
+
def extract_features(self, audio_data):
"""
- Extracts the MFCC features, and calculates each features first and second order derivative.
+        Extracts the MFCC features. Also calculates each feature's first and second order derivatives
+        if the mfcc_delta_calc() function has been implemented by a child class.
The matrix returned should be sized appropriately for input to the model, based
on the model info specified in the MFCC instance.
Args:
- mfcc_instance: The instance of MFCC used for this calculation
audio_data: the audio data to be used for this calculation
Returns:
the derived MFCC feature vector, sized appropriately for inference
"""
num_samples_per_inference = ((self.model_input_size - 1)
- * self.stride) + self.__mfcc_calc.mfcc_params.frame_len
+ * self.stride) + self._mfcc_calc.mfcc_params.frame_len
+
if len(audio_data) < num_samples_per_inference:
raise ValueError("audio_data size for feature extraction is smaller than "
"the expected number of samples needed for inference")
features = []
- self.__get_features(features, self.__mfcc_calc, np.asarray(audio_data))
- features = np.reshape(np.array(features), (self.model_input_size, self.__mfcc_calc.mfcc_params.num_mfcc_feats))
-
- mfcc_delta_np = np.zeros_like(features)
- mfcc_delta2_np = np.zeros_like(features)
-
- for i in range(features.shape[1]):
- idelta = np.convolve(features[:, i], self.__savgol_order1_coeffs, 'same')
- mfcc_delta_np[:, i] = (idelta)
- ideltadelta = np.convolve(features[:, i], self.savgol_order2_coeffs, 'same')
- mfcc_delta2_np[:, i] = (ideltadelta)
-
- features = np.concatenate((self.__normalize(features), self.__normalize(mfcc_delta_np),
- self.__normalize(mfcc_delta2_np)), axis=1)
-
+ self._get_features(features, self._mfcc_calc, np.asarray(audio_data))
+ features = np.reshape(np.array(features), (self.model_input_size, self._mfcc_calc.mfcc_params.num_mfcc_feats))
+ features = self.mfcc_delta_calc(features)
return np.float32(features)
diff --git a/python/pyarmnn/examples/common/tests/context.py b/python/pyarmnn/examples/common/tests/context.py
deleted file mode 100644
index 72246c03bf..0000000000
--- a/python/pyarmnn/examples/common/tests/context.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import os
-import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
-
-import cv_utils
-import network_executor
-import utils
diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py
index cf09fdefb8..d4dadf80a4 100644
--- a/python/pyarmnn/examples/common/utils.py
+++ b/python/pyarmnn/examples/common/utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Contains helper functions that can be used across the example apps."""
@@ -8,6 +8,7 @@ import errno
from pathlib import Path
import numpy as np
+import pyarmnn as ann
def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
@@ -39,3 +40,69 @@ def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
else:
labels[idx] = line.strip("\n")
return labels
+
+
+def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
+ """
+ Takes a block of audio data, extracts the MFCC features, quantizes the array, and uses ArmNN to create the
+ input tensors.
+
+ Args:
+        audio_data: The audio data to process
+        input_binding_info: The model input binding info
+        mfcc_preprocessor: The MFCC preprocessor instance
+ Returns:
+ input_tensors: the prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor
+ """
+
+ data_type = input_binding_info[1].GetDataType()
+ input_tensor = mfcc_preprocessor.extract_features(audio_data)
+ if data_type != ann.DataType_Float32:
+ input_tensor = quantize_input(input_tensor, input_binding_info)
+ input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+ return input_tensors
+
+
+def quantize_input(data, input_binding_info):
+    """Quantize the float input to (u)int8, ready for input to the model."""
+ if data.ndim != 2:
+ raise RuntimeError("Audio data must have 2 dimensions for quantization")
+
+ quant_scale = input_binding_info[1].GetQuantizationScale()
+ quant_offset = input_binding_info[1].GetQuantizationOffset()
+ data_type = input_binding_info[1].GetDataType()
+
+ if data_type == ann.DataType_QAsymmS8:
+ data_type = np.int8
+ elif data_type == ann.DataType_QAsymmU8:
+ data_type = np.uint8
+ else:
+ raise ValueError("Could not quantize data to required data type")
+
+ d_min = np.iinfo(data_type).min
+ d_max = np.iinfo(data_type).max
+
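+    # Affine quantization: q = clip(x / scale + offset, d_min, d_max), then cast to the target integer type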
+ for row in range(data.shape[0]):
+ for col in range(data.shape[1]):
+ data[row, col] = (data[row, col] / quant_scale) + quant_offset
+ data[row, col] = np.clip(data[row, col], d_min, d_max)
+ data = data.astype(data_type)
+ return data
+
+
+def dequantize_output(data, output_binding_info):
+ """Dequantize the (u)int8 output to float"""
+
+ if output_binding_info[1].IsQuantized():
+ if data.ndim != 2:
+            raise RuntimeError("Data must have 2 dimensions for dequantization")
+
+ quant_scale = output_binding_info[1].GetQuantizationScale()
+ quant_offset = output_binding_info[1].GetQuantizationOffset()
+
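+        # Dequantization: x = (q - offset) * scale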
+ data = data.astype(float)
+ for row in range(data.shape[0]):
+ for col in range(data.shape[1]):
+ data[row, col] = (data[row, col] - quant_offset)*quant_scale
+ return data
diff --git a/python/pyarmnn/examples/keyword_spotting/README.MD b/python/pyarmnn/examples/keyword_spotting/README.MD
new file mode 100644
index 0000000000..4299fa0fd4
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/README.MD
@@ -0,0 +1,189 @@
+# Keyword Spotting with PyArmNN
+
+This sample application guides the user through performing Keyword Spotting (KWS) with the PyArmNN API.
+
+## Prerequisites
+
+### PyArmNN
+
+Before proceeding to the next steps, make sure that you have successfully installed the newest version of PyArmNN on your system by following the instructions in the README of the PyArmNN root directory.
+
+You can verify that the PyArmNN library is installed and check its version using:
+
+```bash
+$ pip show pyarmnn
+```
+
+You can also verify the installation by running the following command and checking that the output is similar to the example below:
+
+```bash
+$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
+'26.0.0'
+```
+
+### Dependencies
+
+Install the libsndfile and PortAudio packages:
+
+```bash
+$ sudo apt-get install libsndfile1 libportaudio2
+```
+
+Install the required Python modules:
+
+```bash
+$ pip install -r requirements.txt
+```
+
+### Model
+
+The model we are using is the [DS CNN Large](https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/), which can be found in the [Arm Model Zoo repository](
+https://github.com/ARM-software/ML-zoo/tree/master/models).
+
+A small selection of suitable wav files containing keywords can be found [here](https://git.mlplatform.org/ml/ethos-u/ml-embedded-evaluation-kit.git/plain/resources/kws/samples/).
+
+Labels for this model are defined within run_audio_classification.py.
+
+## Performing Keyword Spotting
+
+### Processing Audio Files
+
+Please ensure that your audio file has a sampling rate of 16000 Hz.
+
+To run KWS on an audio file, use the following command:
+
+```bash
+$ python run_audio_classification.py --audio_file_path <path/to/your_audio> --model_file_path <path/to/your_model>
+```
+
+You may also add the optional flags:
+
+* `--preferred_backends`
+
+  * Takes the preferred backends in preference order, separated by whitespace. For example, passing in "CpuAcc CpuRef" will be read as the list ["CpuAcc", "CpuRef"], which is also the default.
+
+ * CpuAcc represents the CPU backend
+
+ * GpuAcc represents the GPU backend
+
+ * CpuRef represents the CPU reference kernels
+
+* `--help` prints all available options to screen
+
+
+### Processing Audio Streams
+
+To run KWS on a live audio stream, use the following command:
+
+```bash
+$ python run_audio_classification.py --model_file_path <path/to/your_model> [--duration <duration_in_seconds>]
+```
+You will be prompted to select an input microphone and inference will commence
+after 3 seconds.
+
+
+You may also add the following optional flag in addition to those listed above for processing audio files:
+
+* `--duration`
+
+ * Integer number of seconds to perform inference. Default is to continue indefinitely.
+
+## Application Overview
+
+1. [Initialization](#initialization)
+
+2. [Creating a network](#creating-a-network)
+
+3. [Keyword Spotting Pipeline](#keyword-spotting-pipeline)
+
+### Initialization
+
+The application parses the supplied user arguments and loads the audio file or stream in chunks through the `capture_audio()` method which accepts sampling criteria as an `AudioCaptureParams` tuple.
+
+With KWS from an audio file, the application will create a generator object to yield blocks of audio data from the file with a minimum sample size defined in AudioCaptureParams.
+
+MFCC features are extracted from each block based on criteria defined in the `MFCCParams` tuple. These extracted features constitute the input tensors for the model.
+
+To interpret the inference result of the loaded network, the application passes the label dictionary defined in run_audio_classification.py to a decoder and displays the result.
+
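+A minimal sketch of this first stage is shown below. The wav file name and the `preprocessor` object are placeholders; the capture parameter values mirror those set in run_audio_classification.py.
+
+```python
+# Illustrative only: stream blocks of audio from a file and turn each one into model features
+capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000,
+                                    sampling_freq=16000, mono=True)
+for audio_block in capture_audio("example_keyword.wav", capture_params):
+    input_features = preprocessor.extract_features(audio_block)
+```
+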
+### Creating a network
+
+A PyArmNN application must import a graph from file using an appropriate parser. Arm NN provides parsers for various model file types, including TFLite and ONNX. These parsers are libraries for loading neural networks of various formats into the Arm NN runtime.
+
+Arm NN supports optimized execution on multiple CPU, GPU, and Ethos-N devices. Before executing a graph, the application must select the appropriate device context by using `IRuntime()` to create a runtime context with default options. We can optimize the imported graph by specifying a list of backends in order of preference and implementing backend-specific optimizations. Each backend is identified by a unique string; for example, CpuAcc, GpuAcc, and CpuRef represent the accelerated CPU and GPU backends and the CPU reference kernels, respectively.
+
+Arm NN splits the entire graph into subgraphs based on these backends. Each subgraph is then optimized, and the corresponding subgraph in the original graph is substituted with its optimized version.
+
+The `Optimize()` function optimizes the graph for inference, then `LoadNetwork()` loads the optimized network onto the compute device. The `LoadNetwork()` function also creates the backend-specific workloads for the layers and a backend-specific workload factory.
+
+Parsers extract the input information for the network. The `GetSubgraphInputTensorNames()` function extracts all the input names and the `GetNetworkInputBindingInfo()` function obtains the input binding information of the graph. The input binding information contains all the essential information about the input. This information is a tuple consisting of integer identifiers for bindable layers and tensor information (data type, quantization info, dimension count, total elements).
+
+Similarly, we can get the output binding information for an output layer by using the parser to retrieve output tensor names and calling the `GetNetworkOutputBindingInfo()` function.
+
+For this application, the main point of contact with PyArmNN is through the `ArmnnNetworkExecutor` class, which will handle the network creation step for you.
+
+```python
+# common/network_executor.py
+# The provided KWS model is in .tflite format, so we use ITfLiteParser() to import the graph
+if ext == '.tflite':
+ parser = ann.ITfLiteParser()
+network = parser.CreateNetworkFromBinaryFile(model_file)
+...
+# Optimize the network for the list of preferred backends
+opt_network, messages = ann.Optimize(
+ network, preferred_backends, self.runtime.GetDeviceSpec(), ann.OptimizerOptions()
+ )
+# Load the optimized network onto the runtime device
+self.network_id, _ = self.runtime.LoadNetwork(opt_network)
+# Get the input and output binding information
+self.input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+self.output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
+```
+
+### Keyword Spotting pipeline
+
+
+Mel-frequency Cepstral Coefficients (MFCCs, [see Wikipedia](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) are extracted based on criteria defined in the `MFCCParams` tuple and the associated `MFCC` class.
+MFCCs are the result of computing the dot product of the Discrete Cosine Transform (DCT) matrix and the log of the Mel energy.
+
+The `MFCC` class is used in conjunction with the `AudioPreprocessor` class to extract and process MFCC features from a given audio frame.
+
+
+After all the MFCCs needed for an inference have been extracted from the audio data, they constitute the input tensors that will be classified by an `ArmnnNetworkExecutor` object.
+
+```python
+# mfcc.py
+# Extract MFCC features from audio_data
+audio_data.resize(self._frame_len_padded)
+spec = self.spectrum_calc(audio_data)
+mel_energy = np.dot(self._np_mel_bank.astype(np.float32),
+ np.transpose(spec).astype(np.float32))
+log_mel_energy = self.log_mel(mel_energy)
+mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
+```
+
+```python
+# audio_utils.py
+# Quantize the input data and create input tensors with PyArmNN
+input_tensor = quantize_input(input_tensor, input_binding_info)
+input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+```
+
+Note: `ArmnnNetworkExecutor` has already created the output tensors for you.
+
+After creating the workload tensors, the compute device performs inference for the loaded network by using the `EnqueueWorkload()` function of the runtime context. Calling the `workload_tensors_to_ndarray()` function obtains the inference results as a list of ndarrays.
+
+```python
+# common/network_executor.py
+status = runtime.EnqueueWorkload(net_id, input_tensors, self.output_tensors)
+self.output_result = ann.workload_tensors_to_ndarray(self.output_tensors)
+```
+
+The output from the inference must be decoded to obtain the recognised classification. A simple greedy decoder classifies the results by taking the index of the highest element of the output as a key for the labels dictionary. The returned value is a keyword or unknown/silence, which is appended to a list along with the calculated probability. The result is displayed on the console if the probability exceeds the threshold value specified in main().
+
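+A minimal sketch of this decoding step, assuming `scores` is the dequantized output flattened to a 1-D array of per-class probabilities and `threshold` is the cut-off set in main():
+
+```python
+# Illustrative greedy decode: the index of the highest score selects the label
+best_idx = int(np.argmax(scores))
+keyword, probability = labels[best_idx], scores[best_idx]
+if probability > threshold:
+    print(f"Classification: {keyword} ({probability:.2f})")
+```
+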
+
+## Next steps
+
+Having now gained a solid understanding of performing keyword spotting with PyArmNN, you are able to take control and create your own application. We suggest first implementing your own network, which can be done by updating the parameters of `AudioCaptureParams` and `MFCCParams` to match your custom model. The `ArmnnNetworkExecutor` class will handle the network optimisation and loading for you.
+
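+As a purely illustrative example (the values below describe a hypothetical custom model, not one from this repository), retargeting the pipeline mostly means redefining the two parameter tuples:
+
+```python
+# Hypothetical settings for a custom 8 kHz keyword model
+custom_capture_params = AudioCaptureParams(dtype=np.float32, overlap=0, min_samples=8000,
+                                           sampling_freq=8000, mono=True)
+custom_mfcc_params = MFCCParams(sampling_freq=8000, num_fbank_bins=32, mel_lo_freq=0,
+                                mel_hi_freq=4000, num_mfcc_feats=13, frame_len=512,
+                                use_htk_method=False, n_fft=512)
+```
+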
+An important factor in improving the accuracy of the generated output is providing cleaner data to the network. This can be done by including additional preprocessing steps, such as noise reduction of your audio data.
diff --git a/python/pyarmnn/examples/keyword_spotting/__init__.py b/python/pyarmnn/examples/keyword_spotting/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/__init__.py
diff --git a/python/pyarmnn/examples/keyword_spotting/audio_utils.py b/python/pyarmnn/examples/keyword_spotting/audio_utils.py
new file mode 100644
index 0000000000..723c0e38f6
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/audio_utils.py
@@ -0,0 +1,31 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Utilities for speech recognition apps."""
+
+import numpy as np
+
+
+def decode(model_output: np.ndarray, labels: dict) -> list:
+    """Decodes the integer-encoded results from inference into a label and its probability.
+
+ Args:
+ model_output: Results from running inference.
+ labels: Dictionary of labels keyed on the classification index.
+
+ Returns:
+        A list containing the decoded label and its probability.
+ """
+ results = [labels[np.argmax(model_output)], model_output[0][0][np.argmax(model_output)]]
+
+ return results
+
+
+def display_text(text: list):
+ """Presents the results on the console.
+
+ Args:
+        text: Results of performing KWS on the input audio data.
+ """
+ print('Classification: %s' % text[0])
+ print('Probability: %s' % text[1])
diff --git a/python/pyarmnn/examples/keyword_spotting/requirements.txt b/python/pyarmnn/examples/keyword_spotting/requirements.txt
new file mode 100644
index 0000000000..96782eafd0
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/requirements.txt
@@ -0,0 +1,5 @@
+numpy>=1.19.2
+soundfile>=0.10.3
+pytest==6.2.4
+pytest-allclose==1.0.0
+sounddevice==0.4.2
diff --git a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
new file mode 100644
index 0000000000..6dfa4cc806
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
@@ -0,0 +1,136 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""
+
+import sys
+import os
+from argparse import ArgumentParser
+
+import numpy as np
+import sounddevice as sd
+
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
+from network_executor import ArmnnNetworkExecutor
+from utils import prepare_input_tensors, dequantize_output
+from mfcc import AudioPreprocessor, MFCC, MFCCParams
+from audio_utils import decode, display_text
+from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio
+
+# Model Specific Labels
+labels = {0: 'silence',
+ 1: 'unknown',
+ 2: 'yes',
+ 3: 'no',
+ 4: 'up',
+ 5: 'down',
+ 6: 'left',
+ 7: 'right',
+ 8: 'on',
+ 9: 'off',
+ 10: 'stop',
+ 11: 'go'}
+
+
+def parse_args():
+ parser = ArgumentParser(description="KWS with PyArmNN")
+ parser.add_argument(
+ "--audio_file_path",
+ required=False,
+ type=str,
+ help="Path to the audio file to perform KWS",
+ )
+ parser.add_argument(
+ "--duration",
+ type=int,
+ default=0,
+ help="""Duration for recording audio in seconds. Values <= 0 result in infinite
+ recording. Defaults to infinite.""",
+ )
+ parser.add_argument(
+ "--model_file_path",
+ required=True,
+ type=str,
+ help="Path to KWS model to use",
+ )
+ parser.add_argument(
+ "--preferred_backends",
+ type=str,
+ nargs="+",
+ default=["CpuAcc", "CpuRef"],
+ help="""List of backends in order of preference for optimizing
+ subgraphs, falling back to the next backend in the list on unsupported
+ layers. Defaults to [CpuAcc, CpuRef]""",
+ )
+ return parser.parse_args()
+
+
+def recognise_speech(audio_data, network, preprocessor, threshold):
+ # Prepare the input Tensors
+ input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+ # Run inference
+ output_result = network.run(input_tensors)
+
+ dequantized_result = []
+ for index, ofm in enumerate(output_result):
+ dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index]))
+
+ # Decode the text and display result if above threshold
+ decoded_result = decode(dequantized_result, labels)
+
+ if decoded_result[1] > threshold:
+ display_text(decoded_result)
+
+
+def main(args):
+ # Read command line args and invoke mic streaming if no file path supplied
+ audio_file = args.audio_file_path
+ if args.audio_file_path:
+ streaming_enabled = False
+ else:
+ streaming_enabled = True
+ # Create the ArmNN inference runner
+ network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+
+ # Specify model specific audio data requirements
+ # Overlap value specifies the number of samples to rewind between each data window
+ audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000, sampling_freq=16000,
+ mono=True)
+
+ # Create the preprocessor
+ mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
+ num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
+ mfcc = MFCC(mfcc_params)
+ preprocessor = AudioPreprocessor(mfcc, model_input_size=49, stride=320)
+
+ # Set threshold for displaying classification and commence stream or file processing
+ threshold = .90
+ if streaming_enabled:
+ # Initialise audio stream
+ record_stream = CaptureAudioStream(audio_capture_params)
+ record_stream.set_stream_defaults()
+ record_stream.set_recording_duration(args.duration)
+ record_stream.countdown()
+
+ with sd.InputStream(callback=record_stream.callback):
+ print("Recording audio. Please speak.")
+ while record_stream.is_active:
+
+ audio_data = record_stream.capture_data()
+ recognise_speech(audio_data, network, preprocessor, threshold)
+ record_stream.is_first_window = False
+ print("\nFinished recording.")
+
+    # If a file path has been supplied, read it in and run inference
+ else:
+ print("Processing Audio Frames...")
+ buffer = capture_audio(audio_file, audio_capture_params)
+ for audio_data in buffer:
+ recognise_speech(audio_data, network, preprocessor, threshold)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/python/pyarmnn/examples/object_detection/run_video_file.py b/python/pyarmnn/examples/object_detection/run_video_file.py
index e31b779458..52f19d2c15 100644
--- a/python/pyarmnn/examples/object_detection/run_video_file.py
+++ b/python/pyarmnn/examples/object_detection/run_video_file.py
@@ -1,83 +1,87 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-"""
-Object detection demo that takes a video file, runs inference on each frame producing
-bounding boxes and labels around detected objects, and saves the processed video.
-"""
-
-import os
-import sys
-script_dir = os.path.dirname(__file__)
-sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
-
-import cv2
-from tqdm import tqdm
-from argparse import ArgumentParser
-
-from ssd import ssd_processing, ssd_resize_factor
-from yolo import yolo_processing, yolo_resize_factor
-from utils import dict_labels
-from cv_utils import init_video_file_capture, preprocess, draw_bounding_boxes
-from network_executor import ArmnnNetworkExecutor
-
-
-def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
- """
- Gets model-specific information such as model labels and decoding and processing functions.
- The user can include their own network and functions by adding another statement.
-
- Args:
- model_name: Name of type of supported model.
- video: Video capture object, contains information about data source.
- input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
-
- Returns:
- Model labels, decoding and processing functions.
- """
- if model_name == 'ssd_mobilenet_v1':
- return ssd_processing, ssd_resize_factor(video)
- elif model_name == 'yolo_v3_tiny':
- return yolo_processing, yolo_resize_factor(video, input_binding_info)
- else:
- raise ValueError(f'{model_name} is not a valid model name')
-
-
-def main(args):
- video, video_writer, frame_count = init_video_file_capture(args.video_file_path, args.output_video_file_path)
-
- executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
- process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
- labels = dict_labels(args.label_path, include_rgb=True)
-
- for _ in tqdm(frame_count, desc='Processing frames'):
- frame_present, frame = video.read()
- if not frame_present:
- continue
- input_tensors = preprocess(frame, executor.input_binding_info)
- output_result = executor.run(input_tensors)
- detections = process_output(output_result)
- draw_bounding_boxes(frame, detections, resize_factor, labels)
- video_writer.write(frame)
- print('Finished processing frames')
- video.release(), video_writer.release()
-
-
-if __name__ == '__main__':
- parser = ArgumentParser()
- parser.add_argument('--video_file_path', required=True, type=str,
- help='Path to the video file to run object detection on')
- parser.add_argument('--model_file_path', required=True, type=str,
- help='Path to the Object Detection model to use')
- parser.add_argument('--model_name', required=True, type=str,
- help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
- parser.add_argument('--label_path', required=True, type=str,
- help='Path to the labelset for the provided model file')
- parser.add_argument('--output_video_file_path', type=str,
- help='Path to the output video file with detections added in')
- parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
- help='Takes the preferred backends in preference order, separated by whitespace, '
- 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
- 'Defaults to [CpuAcc, CpuRef]')
- args = parser.parse_args()
- main(args)
+# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+Object detection demo that takes a video file, runs inference on each frame producing
+bounding boxes and labels around detected objects, and saves the processed video.
+"""
+
+import os
+import sys
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
+import cv2
+from tqdm import tqdm
+from argparse import ArgumentParser
+
+from ssd import ssd_processing, ssd_resize_factor
+from yolo import yolo_processing, yolo_resize_factor
+from utils import dict_labels
+from cv_utils import init_video_file_capture, preprocess, draw_bounding_boxes
+from network_executor import ArmnnNetworkExecutor
+
+
+def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
+ """
+ Gets model-specific information such as model labels and decoding and processing functions.
+ The user can include their own network and functions by adding another statement.
+
+ Args:
+ model_name: Name of type of supported model.
+ video: Video capture object, contains information about data source.
+ input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
+
+ Returns:
+ Model labels, decoding and processing functions.
+ """
+ if model_name == 'ssd_mobilenet_v1':
+ return ssd_processing, ssd_resize_factor(video)
+ elif model_name == 'yolo_v3_tiny':
+ return yolo_processing, yolo_resize_factor(video, input_binding_info)
+ else:
+ raise ValueError(f'{model_name} is not a valid model name')
+
+
+def main(args):
+ video, video_writer, frame_count = init_video_file_capture(args.video_file_path, args.output_video_file_path)
+
+ executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+ process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
+ labels = dict_labels(args.label_path, include_rgb=True)
+
+ for _ in tqdm(frame_count, desc='Processing frames'):
+ frame_present, frame = video.read()
+ if not frame_present:
+ continue
+ model_name = args.model_name
+ if model_name == "ssd_mobilenet_v1":
+ input_tensors = preprocess(frame, executor.input_binding_info, True)
+ else:
+ input_tensors = preprocess(frame, executor.input_binding_info, False)
+ output_result = executor.run(input_tensors)
+ detections = process_output(output_result)
+ draw_bounding_boxes(frame, detections, resize_factor, labels)
+ video_writer.write(frame)
+ print('Finished processing frames')
+ video.release(), video_writer.release()
+
+
+if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--video_file_path', required=True, type=str,
+ help='Path to the video file to run object detection on')
+ parser.add_argument('--model_file_path', required=True, type=str,
+ help='Path to the Object Detection model to use')
+ parser.add_argument('--model_name', required=True, type=str,
+ help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
+ parser.add_argument('--label_path', required=True, type=str,
+ help='Path to the labelset for the provided model file')
+ parser.add_argument('--output_video_file_path', type=str,
+ help='Path to the output video file with detections added in')
+ parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
+ help='Takes the preferred backends in preference order, separated by whitespace, '
+ 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
+ 'Defaults to [CpuAcc, CpuRef]')
+ args = parser.parse_args()
+ main(args)
diff --git a/python/pyarmnn/examples/object_detection/run_video_stream.py b/python/pyarmnn/examples/object_detection/run_video_stream.py
index 8635a40a9e..dba615b97e 100644
--- a/python/pyarmnn/examples/object_detection/run_video_stream.py
+++ b/python/pyarmnn/examples/object_detection/run_video_stream.py
@@ -1,85 +1,90 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-"""
-Object detection demo that takes a video stream from a device, runs inference
-on each frame producing bounding boxes and labels around detected objects,
-and displays a window with the latest processed frame.
-"""
-
-import os
-import sys
-script_dir = os.path.dirname(__file__)
-sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
-
-import cv2
-from argparse import ArgumentParser
-
-from ssd import ssd_processing, ssd_resize_factor
-from yolo import yolo_processing, yolo_resize_factor
-from utils import dict_labels
-from cv_utils import init_video_stream_capture, preprocess, draw_bounding_boxes
-from network_executor import ArmnnNetworkExecutor
-
-
-def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
- """
- Gets model-specific information such as model labels and decoding and processing functions.
- The user can include their own network and functions by adding another statement.
-
- Args:
- model_name: Name of type of supported model.
- video: Video capture object, contains information about data source.
- input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
-
- Returns:
- Model labels, decoding and processing functions.
- """
- if model_name == 'ssd_mobilenet_v1':
- return ssd_processing, ssd_resize_factor(video)
- elif model_name == 'yolo_v3_tiny':
- return yolo_processing, yolo_resize_factor(video, input_binding_info)
- else:
- raise ValueError(f'{model_name} is not a valid model name')
-
-
-def main(args):
- video = init_video_stream_capture(args.video_source)
- executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
-
- process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
- labels = dict_labels(args.label_path, include_rgb=True)
-
- while True:
- frame_present, frame = video.read()
- frame = cv2.flip(frame, 1) # Horizontally flip the frame
- if not frame_present:
- raise RuntimeError('Error reading frame from video stream')
- input_tensors = preprocess(frame, executor.input_binding_info)
- print("Running inference...")
- output_result = executor.run(input_tensors)
- detections = process_output(output_result)
- draw_bounding_boxes(frame, detections, resize_factor, labels)
- cv2.imshow('PyArmNN Object Detection Demo', frame)
- if cv2.waitKey(1) == 27:
- print('\nExit key activated. Closing video...')
- break
- video.release(), cv2.destroyAllWindows()
-
-
-if __name__ == '__main__':
- parser = ArgumentParser()
- parser.add_argument('--video_source', type=int, default=0,
- help='Device index to access video stream. Defaults to primary device camera at index 0')
- parser.add_argument('--model_file_path', required=True, type=str,
- help='Path to the Object Detection model to use')
- parser.add_argument('--model_name', required=True, type=str,
- help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
- parser.add_argument('--label_path', required=True, type=str,
- help='Path to the labelset for the provided model file')
- parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
- help='Takes the preferred backends in preference order, separated by whitespace, '
- 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
- 'Defaults to [CpuAcc, CpuRef]')
- args = parser.parse_args()
- main(args)
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+Object detection demo that takes a video stream from a device, runs inference
+on each frame producing bounding boxes and labels around detected objects,
+and displays a window with the latest processed frame.
+"""
+
+import os
+import sys
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
+import cv2
+from argparse import ArgumentParser
+
+from ssd import ssd_processing, ssd_resize_factor
+from yolo import yolo_processing, yolo_resize_factor
+from utils import dict_labels
+from cv_utils import init_video_stream_capture, preprocess, draw_bounding_boxes
+from network_executor import ArmnnNetworkExecutor
+
+
+def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
+ """
+ Gets model-specific information such as model labels and decoding and processing functions.
+ The user can include their own network and functions by adding another statement.
+
+ Args:
+ model_name: Name of type of supported model.
+ video: Video capture object, contains information about data source.
+ input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
+
+ Returns:
+ Model labels, decoding and processing functions.
+ """
+ if model_name == 'ssd_mobilenet_v1':
+ return ssd_processing, ssd_resize_factor(video)
+ elif model_name == 'yolo_v3_tiny':
+ return yolo_processing, yolo_resize_factor(video, input_binding_info)
+ else:
+ raise ValueError(f'{model_name} is not a valid model name')
+
+
+def main(args):
+ video = init_video_stream_capture(args.video_source)
+ executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+
+ model_name = args.model_name
+ process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
+ labels = dict_labels(args.label_path, include_rgb=True)
+
+ while True:
+ frame_present, frame = video.read()
+ frame = cv2.flip(frame, 1) # Horizontally flip the frame
+ if not frame_present:
+ raise RuntimeError('Error reading frame from video stream')
+
+ if model_name == "ssd_mobilenet_v1":
+ input_tensors = preprocess(frame, executor.input_binding_info, True)
+ else:
+ input_tensors = preprocess(frame, executor.input_binding_info, False)
+ print("Running inference...")
+ output_result = executor.run(input_tensors)
+ detections = process_output(output_result)
+ draw_bounding_boxes(frame, detections, resize_factor, labels)
+ cv2.imshow('PyArmNN Object Detection Demo', frame)
+ if cv2.waitKey(1) == 27:
+ print('\nExit key activated. Closing video...')
+ break
+ video.release(), cv2.destroyAllWindows()
+
+
+if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--video_source', type=int, default=0,
+ help='Device index to access video stream. Defaults to primary device camera at index 0')
+ parser.add_argument('--model_file_path', required=True, type=str,
+ help='Path to the Object Detection model to use')
+ parser.add_argument('--model_name', required=True, type=str,
+ help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
+ parser.add_argument('--label_path', required=True, type=str,
+ help='Path to the labelset for the provided model file')
+ parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
+ help='Takes the preferred backends in preference order, separated by whitespace, '
+ 'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
+ 'Defaults to [CpuAcc, CpuRef]')
+ args = parser.parse_args()
+ main(args)
diff --git a/python/pyarmnn/examples/speech_recognition/README.md b/python/pyarmnn/examples/speech_recognition/README.md
index c4096efcc5..c39959bfbc 100644
--- a/python/pyarmnn/examples/speech_recognition/README.md
+++ b/python/pyarmnn/examples/speech_recognition/README.md
@@ -29,31 +29,31 @@ Install the PortAudio package:
$ sudo apt-get install libsndfile1 libportaudio2
```
-Install the required Python modules:
+Install the required Python modules:
```bash
$ pip install -r requirements.txt
```
### Model
+The model we are using is the [Wav2Letter](https://github.com/ARM-software/ML-zoo/tree/master/models/speech_recognition/wav2letter/tflite_int8), which can be found in the [Arm Model Zoo repository](
+https://github.com/ARM-software/ML-zoo/tree/master/models).
-The model for this can be found in the Arm Model Zoo repository:
-https://github.com/ARM-software/ML-zoo/tree/master/models
-
-The model we're looking for:
-https://github.com/ARM-software/ML-zoo/tree/master/models/speech_recognition/wav2letter/tflite_int8
+A small selection of suitable wav files containing human speech can be found [here](https://github.com/Azure-Samples/cognitive-services-speech-sdk/tree/master/sampledata/audiofiles).
+Labels for this model are defined within run_audio_file.py.
## Performing Automatic Speech Recognition
### Processing Audio Files
+Please ensure that your audio file has a sampling rate of 16000 Hz.
+
To run ASR on an audio file, use the following command:
```bash
-$ python run_audio_file.py --audio_file_path <path/to/your_audio> --model_file_path <path/to/your_model> --labels_file_path <path/to/your_labels>
+$ python run_audio_file.py --audio_file_path <path/to/your_audio> --model_file_path <path/to/your_model>
```
-Please ensure that your audio file has a sampling rate of 16000Hz.
You may also add the optional flags:
@@ -79,15 +79,18 @@ You may also add the optional flags:
### Initialization
-The application parses the supplied user arguments and loads the audio file into the `AudioCapture` class, which initialises the audio source and sets sampling parameters required by the model with `ModelParams` class.
+The application parses the supplied user arguments and loads the audio file in chunks through the `capture_audio()` method, which accepts sampling criteria as an `AudioCaptureParams` tuple.
-`AudioCapture` helps us to capture chunks of audio data from the source. With ASR from an audio file, the application will create a generator object to yield blocks of audio data from the file with a minimum sample size.
+With ASR from an audio file, the application creates a generator object to yield blocks of audio data from the file, with a minimum sample size defined in `AudioCaptureParams`.
-To interpret the inference result of the loaded network, the application must load the labels that are associated with the model. The `dict_labels()` function creates a dictionary that is keyed on the classification index at the output node of the model. The values of the dictionary are the corresponding characters.
+MFCC features are extracted from each block based on criteria defined in the `MFCCParams` tuple. These extracted features constitute the input tensors for the model.
+
+To interpret the inference result of the loaded network, the application passes the label dictionary defined in `run_audio_file.py` to a decoder and displays the result.
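+
+As a rough sketch, the capture and decode steps in `run_audio_file.py` look like this (the parameter values shown are the wav2letter settings used by the example script):
+
+```python
+# Model-specific audio requirements for wav2letter
+audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=31712, min_samples=47712,
+                                          sampling_freq=16000, mono=True)
+
+# Generator that yields blocks of at least min_samples audio samples from the file
+buffer = capture_audio(audio_file, audio_capture_params)
+
+# For each block: extract features, run inference, then decode with the label dictionary
+current_r_context, text = decode_text(is_first_window, labels, output_result)
+display_text(text)
+```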
### Creating a network
-A PyArmNN application must import a graph from file using an appropriate parser. Arm NN provides parsers for various model file types, including TFLite, TF, and ONNX. These parsers are libraries for loading neural networks of various formats into the Arm NN runtime.
+A PyArmNN application must import a graph from file using an appropriate parser. Arm NN provides parsers for various model file types, including TFLite and ONNX. These parsers are libraries for loading neural networks of various formats into the Arm NN runtime.
Arm NN supports optimized execution on multiple CPU, GPU, and Ethos-N devices. Before executing a graph, the application must select the appropriate device context by using `IRuntime()` to create a runtime context with default options. We can optimize the imported graph by specifying a list of backends in order of preference and implementing backend-specific optimizations, identified by a unique string, for example CpuAcc, GpuAcc, CpuRef represent the accelerated CPU and GPU backends and the CPU reference kernels respectively.
@@ -120,13 +123,17 @@ self.output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_n
```
### Automatic speech recognition pipeline
+Mel-frequency Cepstral Coefficients (MFCCs, [see Wikipedia](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) are extracted based on criteria defined in the `MFCCParams` tuple and the associated `MFCC` class.
+MFCCs are the result of computing the dot product of the Discrete Cosine Transform (DCT) Matrix and the log of the Mel energy.
+
+The `MFCC` class is used in conjunction with the `AudioPreprocessor` class to extract and process MFCC features from a given audio frame.
+
-The `MFCC` class is used to extract the Mel-frequency Cepstral Coefficients (MFCCs, [see Wikipedia](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) from a given audio frame to be used as features for the network. MFCCs are the result of computing the dot product of the Discrete Cosine Transform (DCT) Matrix and the log of the Mel energy.
+After all the MFCCs needed for an inference have been extracted from the audio data, we convolve them with 1-dimensional Savitzky-Golay filters to compute the first and second MFCC derivatives with respect to time. The MFCCs and the derivatives constitute the input tensors that will be classified by an `ArmnnNetworkExecutor` object. The derivative step is shown at the end of the snippet below.
-After all the MFCCs needed for an inference have been extracted from the audio data, we convolve them with 1-dimensional Savitzky-Golay filters to compute the first and second MFCC derivatives with respect to time. The MFCCs and the derivatives are concatenated to make the input tensor for the model.
```python
-# preprocess.py
+# mfcc.py & wav2letter_mfcc.py
# Extract MFCC features
log_mel_energy = np.maximum(log_mel_energy, log_mel_energy.max() - top_db)
mfcc_feats = np.dot(self.__dct_matrix, log_mel_energy)
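+
+# wav2letter_mfcc.py (W2LAudioPreprocessor.mfcc_delta_calc)
+# Each MFCC column is then convolved with fixed Savitzky-Golay coefficients to obtain the
+# first and second time derivatives, which are normalised and concatenated with the MFCCs.
+idelta = np.convolve(features[:, i], self.savgol_order1_coeffs, 'same')
+ideltadelta = np.convolve(features[:, i], self.savgol_order2_coeffs, 'same')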
@@ -165,4 +172,4 @@ Having now gained a solid understanding of performing automatic speech recogniti
An important step to improving accuracy of the generated output sentences is by providing cleaner data to the network. This can be done by including additional preprocessing steps such as noise reduction of your audio data.
-In this application, we had used a greedy decoder to decode the integer-encoded output however, better results can be achieved by implementing a beam search decoder. You may even try adding a language model at the end to aim to correct any spelling mistakes the model may produce.
+In this application, we used a greedy decoder to decode the integer-encoded output; however, better results can be achieved by implementing a beam search decoder. You may even try adding a language model at the end to correct any spelling mistakes the model may produce. \ No newline at end of file
diff --git a/python/pyarmnn/examples/speech_recognition/audio_capture.py b/python/pyarmnn/examples/speech_recognition/audio_capture.py
deleted file mode 100644
index 0c899208a4..0000000000
--- a/python/pyarmnn/examples/speech_recognition/audio_capture.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-"""Contains AudioCapture class for capturing chunks of audio data from file."""
-
-from typing import Generator
-
-import numpy as np
-import soundfile as sf
-
-
-class ModelParams:
- def __init__(self, model_file_path: str):
- """Defines sampling parameters for model used.
-
- Args:
- model_file_path: Path to ASR model to use.
- """
- self.path = model_file_path
- self.mono = True
- self.dtype = np.float32
- self.samplerate = 16000
- self.min_samples = 47712 # (model_input_size-1)*stride + frame_len
-
-
-class AudioCapture:
- def __init__(self, model_params):
- """Sampling parameters for model used."""
- self.model_params = model_params
-
- def from_audio_file(self, audio_file_path, overlap=31712) -> Generator[np.ndarray, None, None]:
- """Creates a generator that yields audio data from a file. Data is padded with
- zeros if necessary to make up minimum number of samples.
-
- Args:
- audio_file_path: Path to audio file provided by user.
- overlap: The overlap with previous buffer. We need the offset to be the same as the inner context
- of the mfcc output, which is sized as 100 x 39. Each mfcc compute produces 1 x 39 vector,
- and consumes 160 audio samples. The default overlap is then calculated to be 47712 - (160 x 100)
- where 47712 is the min_samples needed for 1 inference of wav2letter.
-
- Yields:
- Blocks of audio data of minimum sample size.
- """
- with sf.SoundFile(audio_file_path) as audio_file:
- for block in audio_file.blocks(
- blocksize=self.model_params.min_samples,
- dtype=self.model_params.dtype,
- always_2d=True,
- fill_value=0,
- overlap=overlap
- ):
- # Convert to mono if specified
- if self.model_params.mono and block.shape[0] > 1:
- block = np.mean(block, axis=1)
- yield block
diff --git a/python/pyarmnn/examples/speech_recognition/audio_utils.py b/python/pyarmnn/examples/speech_recognition/audio_utils.py
index f03d2e1290..1ac78e8074 100644
--- a/python/pyarmnn/examples/speech_recognition/audio_utils.py
+++ b/python/pyarmnn/examples/speech_recognition/audio_utils.py
@@ -1,10 +1,9 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Utilities for speech recognition apps."""
import numpy as np
-import pyarmnn as ann
def decode(model_output: np.ndarray, labels: dict) -> str:
@@ -50,33 +49,6 @@ def display_text(text: str):
print(text, sep="", end="", flush=True)
-def quantize_input(data, input_binding_info):
- """Quantize the float input to (u)int8 ready for inputting to model."""
- if data.ndim != 2:
- raise RuntimeError("Audio data must have 2 dimensions for quantization")
-
- quant_scale = input_binding_info[1].GetQuantizationScale()
- quant_offset = input_binding_info[1].GetQuantizationOffset()
- data_type = input_binding_info[1].GetDataType()
-
- if data_type == ann.DataType_QAsymmS8:
- data_type = np.int8
- elif data_type == ann.DataType_QAsymmU8:
- data_type = np.uint8
- else:
- raise ValueError("Could not quantize data to required data type")
-
- d_min = np.iinfo(data_type).min
- d_max = np.iinfo(data_type).max
-
- for row in range(data.shape[0]):
- for col in range(data.shape[1]):
- data[row, col] = (data[row, col] / quant_scale) + quant_offset
- data[row, col] = np.clip(data[row, col], d_min, d_max)
- data = data.astype(data_type)
- return data
-
-
def decode_text(is_first_window, labels, output_result):
"""
Slices the text appropriately depending on the window, and decodes for wav2letter output.
@@ -88,7 +60,6 @@ def decode_text(is_first_window, labels, output_result):
is_first_window: Boolean to show if it is the first window we are running inference on
labels: the label set
output_result: the output from the inference
- text: the current text string, to be displayed at the end
Returns:
current_r_context: the current right context
text: the current text string, with the latest output decoded and appended
@@ -109,25 +80,3 @@ def decode_text(is_first_window, labels, output_result):
# Store the right context, we will need it after the last inference
current_r_context = decode(output_result[0][0][0][right_context_start:], labels)
return current_r_context, text
-
-
-def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
- """
- Takes a block of audio data, extracts the MFCC features, quantizes the array, and uses ArmNN to create the
- input tensors.
-
- Args:
- audio_data: The audio data to process
- mfcc_instance: the mfcc class instance
- input_binding_info: the model input binding info
- mfcc_preprocessor: the mfcc preprocessor instance
- Returns:
- input_tensors: the prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor
- """
-
- data_type = input_binding_info[1].GetDataType()
- input_tensor = mfcc_preprocessor.extract_features(audio_data)
- if data_type != ann.DataType_Float32:
- input_tensor = quantize_input(input_tensor, input_binding_info)
- input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
- return input_tensors
diff --git a/python/pyarmnn/examples/speech_recognition/requirements.txt b/python/pyarmnn/examples/speech_recognition/requirements.txt
index 4b8f3e6d24..96782eafd0 100644
--- a/python/pyarmnn/examples/speech_recognition/requirements.txt
+++ b/python/pyarmnn/examples/speech_recognition/requirements.txt
@@ -1,2 +1,5 @@
numpy>=1.19.2
-soundfile>=0.10.3 \ No newline at end of file
+soundfile>=0.10.3
+pytest==6.2.4
+pytest-allclose==1.0.0
+sounddevice==0.4.2
diff --git a/python/pyarmnn/examples/speech_recognition/run_audio_file.py b/python/pyarmnn/examples/speech_recognition/run_audio_file.py
index 942de2081c..0430f68c16 100644
--- a/python/pyarmnn/examples/speech_recognition/run_audio_file.py
+++ b/python/pyarmnn/examples/speech_recognition/run_audio_file.py
@@ -1,20 +1,29 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Automatic speech recognition with PyArmNN demo for processing audio clips to text."""
import sys
import os
-from argparse import ArgumentParser
+import numpy as np
script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+from argparse import ArgumentParser
from network_executor import ArmnnNetworkExecutor
-from utils import dict_labels
-from preprocess import MFCCParams, Preprocessor, MFCC
-from audio_capture import AudioCapture, ModelParams
-from audio_utils import decode_text, prepare_input_tensors, display_text
+from utils import prepare_input_tensors
+from audio_capture import AudioCaptureParams, capture_audio
+from audio_utils import decode_text, display_text
+from wav2letter_mfcc import Wav2LetterMFCC, W2LAudioPreprocessor
+from mfcc import MFCCParams
+
+# Model Specific Labels
+labels = {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j', 10: 'k', 11: 'l', 12: 'm',
+          13: 'n', 14: 'o', 15: 'p', 16: 'q', 17: 'r', 18: 's', 19: 't', 20: 'u', 21: 'v', 22: 'w', 23: 'x', 24: 'y',
+          25: 'z', 26: "'", 27: ' ', 28: '$'}
def parse_args():
@@ -32,12 +41,6 @@ def parse_args():
help="Path to ASR model to use",
)
parser.add_argument(
- "--labels_file_path",
- required=True,
- type=str,
- help="Path to text file containing labels to map to model output",
- )
- parser.add_argument(
"--preferred_backends",
type=str,
nargs="+",
@@ -52,22 +55,23 @@ def parse_args():
def main(args):
# Read command line args
audio_file = args.audio_file_path
- model = ModelParams(args.model_file_path)
- labels = dict_labels(args.labels_file_path)
# Create the ArmNN inference runner
- network = ArmnnNetworkExecutor(model.path, args.preferred_backends)
+ network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+
+ # Specify model specific audio data requirements
+ audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=31712, min_samples=47712, sampling_freq=16000,
+ mono=True)
+
+ buffer = capture_audio(audio_file, audio_capture_params)
- audio_capture = AudioCapture(model)
- buffer = audio_capture.from_audio_file(audio_file)
+    # Create the preprocessor used for MFCC feature extraction
- # Create the preprocessor
mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=128, mel_lo_freq=0, mel_hi_freq=8000,
- num_mfcc_feats=13, frame_len=512, use_htk_method=False, n_FFT=512)
- mfcc = MFCC(mfcc_params)
- preprocessor = Preprocessor(mfcc, model_input_size=296, stride=160)
+ num_mfcc_feats=13, frame_len=512, use_htk_method=False, n_fft=512)
- text = ""
+ wmfcc = Wav2LetterMFCC(mfcc_params)
+ preprocessor = W2LAudioPreprocessor(wmfcc, model_input_size=296, stride=160)
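+    # The preprocessor turns each audio block into model_input_size (296) frames of MFCCs plus
+    # their first and second derivatives, which together form the input tensor for one inference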
current_r_context = ""
is_first_window = True
diff --git a/python/pyarmnn/examples/speech_recognition/tests/conftest.py b/python/pyarmnn/examples/speech_recognition/tests/conftest.py
index 730c291cfa..151816e919 100644
--- a/python/pyarmnn/examples/speech_recognition/tests/conftest.py
+++ b/python/pyarmnn/examples/speech_recognition/tests/conftest.py
@@ -1,34 +1,24 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-import ntpath
-
-import urllib.request
-
-import pytest
-
-script_dir = os.path.dirname(__file__)
-
-@pytest.fixture(scope="session")
-def test_data_folder(request):
- """
- This fixture returns path to folder with shared test resources among all tests
- """
-
- data_dir = os.path.join(script_dir, "testdata")
-
- if not os.path.exists(data_dir):
- os.mkdir(data_dir)
-
- files_to_download = ["https://raw.githubusercontent.com/Azure-Samples/cognitive-services-speech-sdk/master"
- "/sampledata/audiofiles/myVoiceIsMyPassportVerifyMe04.wav"]
-
- for file in files_to_download:
- path, filename = ntpath.split(file)
- file_path = os.path.join(script_dir, "testdata", filename)
- if not os.path.exists(file_path):
- print("\nDownloading test file: " + file_path + "\n")
- urllib.request.urlretrieve(file, file_path)
-
- return data_dir
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+
+import pytest
+
+script_dir = os.path.dirname(__file__)
+
+@pytest.fixture(scope="session")
+def test_data_folder():
+    """
+    This fixture returns the path to the folder with shared test resources for the ASR tests.
+    """
+
+ data_dir = os.path.join(script_dir, "testdata")
+
+ if not os.path.exists(data_dir):
+ os.mkdir(data_dir)
+
+ return data_dir \ No newline at end of file
diff --git a/python/pyarmnn/examples/speech_recognition/tests/context.py b/python/pyarmnn/examples/speech_recognition/tests/context.py
deleted file mode 100644
index a810010e9f..0000000000
--- a/python/pyarmnn/examples/speech_recognition/tests/context.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-import sys
-
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'common'))
-import utils as common_utils
-
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
-import audio_capture
-import audio_utils
-import preprocess
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py b/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py
deleted file mode 100644
index 281d0df587..0000000000
--- a/python/pyarmnn/examples/speech_recognition/tests/test_audio_file.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-
-import numpy as np
-
-from context import audio_capture
-
-
-def test_audio_file(test_data_folder):
- audio_file = os.path.join(test_data_folder, "myVoiceIsMyPassportVerifyMe04.wav")
- capture = audio_capture.AudioCapture(audio_capture.ModelParams(""))
- buffer = capture.from_audio_file(audio_file)
- audio_data = next(buffer)
- assert audio_data.shape == (47712,)
- assert audio_data.dtype == np.float32
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py b/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py
index 1db71a47b8..14db7f2064 100644
--- a/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py
+++ b/python/pyarmnn/examples/speech_recognition/tests/test_decoder.py
@@ -5,13 +5,16 @@ import os
import numpy as np
-from context import common_utils
from context import audio_utils
+labels = {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j', 10: 'k', 11: 'l', 12: 'm',
+          13: 'n', 14: 'o', 15: 'p', 16: 'q', 17: 'r', 18: 's', 19: 't', 20: 'u', 21: 'v', 22: 'w', 23: 'x', 24: 'y',
+          25: 'z', 26: "'", 27: ' ', 28: '$'}
+
def test_labels(test_data_folder):
- labels_file = os.path.join(test_data_folder, "wav2letter_labels.txt")
- labels = common_utils.dict_labels(labels_file)
assert len(labels) == 29
assert labels[26] == "\'"
assert labels[27] == r" "
@@ -19,10 +22,8 @@ def test_labels(test_data_folder):
def test_decoder(test_data_folder):
- labels_file = os.path.join(test_data_folder, "wav2letter_labels.txt")
- labels = common_utils.dict_labels(labels_file)
- output_tensor = os.path.join(test_data_folder, "inf_out.npy")
+ output_tensor = os.path.join(test_data_folder, "inference_output.npy")
encoded = np.load(output_tensor)
decoded_text = audio_utils.decode(encoded, labels)
- assert decoded_text == "and they walkd immediately out of the apartiment by anothe"
+ assert decoded_text == "my voice is my pass"
diff --git a/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py b/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py
deleted file mode 100644
index d692ab51c8..0000000000
--- a/python/pyarmnn/examples/speech_recognition/tests/test_mfcc.py
+++ /dev/null
@@ -1,286 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import numpy as np
-
-from context import preprocess
-
-test_wav = [
- -3,0,1,-1,2,3,-2,2,
- 1,-2,0,3,-1,8,3,2,
- -1,-1,2,7,3,5,6,6,
- 6,12,5,6,3,3,5,4,
- 4,6,7,7,7,3,7,2,
- 8,4,4,2,-4,-1,-1,-4,
- 2,1,-1,-4,0,-7,-6,-2,
- -5,1,-5,-1,-7,-3,-3,-7,
- 0,-3,3,-5,0,1,-2,-2,
- -3,-3,-7,-3,-2,-6,-5,-8,
- -2,-8,4,-9,-4,-9,-5,-5,
- -3,-9,-3,-9,-1,-7,-4,1,
- -3,2,-8,-4,-4,-5,1,-3,
- -1,0,-1,-2,-3,-2,-4,-1,
- 1,-1,3,0,3,2,0,0,
- 0,-3,1,1,0,8,3,4,
- 1,5,6,4,7,3,3,0,
- 3,6,7,6,4,5,9,9,
- 5,5,8,1,6,9,6,6,
- 7,1,8,1,5,0,5,5,
- 0,3,2,7,2,-3,3,0,
- 3,0,0,0,2,0,-1,-1,
- -2,-3,-8,0,1,0,-3,-3,
- -3,-2,-3,-3,-4,-6,-2,-8,
- -9,-4,-1,-5,-3,-3,-4,-3,
- -6,3,0,-1,-2,-9,-4,-2,
- 2,-1,3,-5,-5,-2,0,-2,
- 0,-1,-3,1,-2,9,4,5,
- 2,2,1,0,-6,-2,0,0,
- 0,-1,4,-4,3,-7,-1,5,
- -6,-1,-5,4,3,9,-2,1,
- 3,0,0,-2,1,2,1,1,
- 0,3,2,-1,3,-3,7,0,
- 0,3,2,2,-2,3,-2,2,
- -3,4,-1,-1,-5,-1,-3,-2,
- 1,-1,3,2,4,1,2,-2,
- 0,2,7,0,8,-3,6,-3,
- 6,1,2,-3,-1,-1,-1,1,
- -2,2,1,2,0,-2,3,-2,
- 3,-2,1,0,-3,-1,-2,-4,
- -6,-5,-8,-1,-4,0,-3,-1,
- -1,-1,0,-2,-3,-7,-1,0,
- 1,5,0,5,1,1,-3,0,
- -6,3,-8,4,-8,6,-6,1,
- -6,-2,-5,-6,0,-5,4,-1,
- 4,-2,1,2,1,0,-2,0,
- 0,2,-2,2,-5,2,0,-2,
- 1,-2,0,5,1,0,1,5,
- 0,8,3,2,2,0,5,-2,
- 3,1,0,1,0,-2,-1,-3,
- 1,-1,3,0,3,0,-2,-1,
- -4,-4,-4,-1,-4,-4,-3,-6,
- -3,-7,-3,-1,-2,0,-5,-4,
- -7,-3,-2,-2,1,2,2,8,
- 5,4,2,4,3,5,0,3,
- 3,6,4,2,2,-2,4,-2,
- 3,3,2,1,1,4,-5,2,
- -3,0,-1,1,-2,2,5,1,
- 4,2,3,1,-1,1,0,6,
- 0,-2,-1,1,-1,2,-5,-1,
- -5,-1,-6,-3,-3,2,4,0,
- -1,-5,3,-4,-1,-3,-4,1,
- -4,1,-1,-1,0,-5,-4,-2,
- -1,-1,-3,-7,-3,-3,4,4,
-]
-
-def test_mel_scale_function_with_htk_true():
- samp_freq = 16000
- frame_len_ms = 32
- frame_len_samples = samp_freq * frame_len_ms * 0.001
- num_mfcc_feats = 13
- num_fbank_bins = 128
- mel_lo_freq = 0
- mil_hi_freq = 8000
- use_htk = False
- n_FFT = 512
-
- mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
- frame_len_samples, use_htk, n_FFT)
-
- mfcc_inst = preprocess.MFCC(mfcc_params)
-
- mel = mfcc_inst.mel_scale(16, True)
-
- assert np.isclose(mel, 25.470010570730597)
-
-
-def test_mel_scale_function_with_htk_false():
- samp_freq = 16000
- frame_len_ms = 32
- frame_len_samples = samp_freq * frame_len_ms * 0.001
- num_mfcc_feats = 13
- num_fbank_bins = 128
- mel_lo_freq = 0
- mil_hi_freq = 8000
- use_htk = False
- n_FFT = 512
-
- mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
- frame_len_samples, use_htk, n_FFT)
-
- mfcc_inst = preprocess.MFCC(mfcc_params)
-
- mel = mfcc_inst.mel_scale(16, False)
-
- assert np.isclose(mel, 0.24)
-
-
-def test_inverse_mel_scale_function_with_htk_true():
- samp_freq = 16000
- frame_len_ms = 32
- frame_len_samples = samp_freq * frame_len_ms * 0.001
- num_mfcc_feats = 13
- num_fbank_bins = 128
- mel_lo_freq = 0
- mil_hi_freq = 8000
- use_htk = False
- n_FFT = 512
-
- mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
- frame_len_samples, use_htk, n_FFT)
-
- mfcc_inst = preprocess.MFCC(mfcc_params)
-
- mel = mfcc_inst.inv_mel_scale(16, True)
-
- assert np.isclose(mel, 10.008767240008943)
-
-
-def test_inverse_mel_scale_function_with_htk_false():
- samp_freq = 16000
- frame_len_ms = 32
- frame_len_samples = samp_freq * frame_len_ms * 0.001
- num_mfcc_feats = 13
- num_fbank_bins = 128
- mel_lo_freq = 0
- mil_hi_freq = 8000
- use_htk = False
- n_FFT = 512
-
- mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
- frame_len_samples, use_htk, n_FFT)
-
- mfcc_inst = preprocess.MFCC(mfcc_params)
-
- mel = mfcc_inst.inv_mel_scale(16, False)
-
- assert np.isclose(mel, 1071.170287494467)
-
-
-def test_create_mel_filter_bank():
- samp_freq = 16000
- frame_len_ms = 32
- frame_len_samples = samp_freq * frame_len_ms * 0.001
- num_mfcc_feats = 13
- num_fbank_bins = 128
- mel_lo_freq = 0
- mil_hi_freq = 8000
- use_htk = False
- n_FFT = 512
-
- mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
- frame_len_samples, use_htk, n_FFT)
-
- mfcc_inst = preprocess.MFCC(mfcc_params)
-
- mel_filter_bank = mfcc_inst.create_mel_filter_bank()
-
- assert len(mel_filter_bank) == 128
-
- assert str(mel_filter_bank[0]) == "[0.02837754]"
- assert str(mel_filter_bank[1]) == "[0.01438901 0.01398853]"
- assert str(mel_filter_bank[2]) == "[0.02877802]"
- assert str(mel_filter_bank[3]) == "[0.04236608]"
- assert str(mel_filter_bank[4]) == "[0.00040047 0.02797707]"
- assert str(mel_filter_bank[5]) == "[0.01478948 0.01358806]"
- assert str(mel_filter_bank[50]) == "[0.03298853]"
- assert str(mel_filter_bank[100]) == "[0.00260166 0.00588759 0.00914814 0.00798015 0.00476919 0.00158245]"
-
-
-def test_mfcc_compute():
- samp_freq = 16000
- frame_len_ms = 32
- frame_len_samples = samp_freq * frame_len_ms * 0.001
- num_mfcc_feats = 13
- num_fbank_bins = 128
- mel_lo_freq = 0
- mil_hi_freq = 8000
- use_htk = False
- n_FFT = 512
-
- audio_data = np.array(test_wav) / (2 ** 15)
-
- mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
- frame_len_samples, use_htk, n_FFT)
- mfcc_inst = preprocess.MFCC(mfcc_params)
- mfcc_feats = mfcc_inst.mfcc_compute(audio_data)
-
- assert np.isclose((mfcc_feats[0]), -834.9656973095651)
- assert np.isclose((mfcc_feats[1]), 21.026915475076322)
- assert np.isclose((mfcc_feats[2]), 18.628541708201688)
- assert np.isclose((mfcc_feats[3]), 7.341153529494758)
- assert np.isclose((mfcc_feats[4]), 18.907974386153214)
- assert np.isclose((mfcc_feats[5]), -5.360387487466194)
- assert np.isclose((mfcc_feats[6]), 6.523572638527085)
- assert np.isclose((mfcc_feats[7]), -11.270643644983316)
- assert np.isclose((mfcc_feats[8]), 8.375177203773777)
- assert np.isclose((mfcc_feats[9]), 12.06721844362991)
- assert np.isclose((mfcc_feats[10]), 8.30815892468875)
- assert np.isclose((mfcc_feats[11]), -13.499911910889917)
- assert np.isclose((mfcc_feats[12]), -18.176121251436165)
-
-
-def test_sliding_window_for_small_num_samples():
- samp_freq = 16000
- frame_len_ms = 32
- frame_len_samples = samp_freq * frame_len_ms * 0.001
- num_mfcc_feats = 13
- mode_input_size = 9
- stride = 160
- num_fbank_bins = 128
- mel_lo_freq = 0
- mil_hi_freq = 8000
- use_htk = False
- n_FFT = 512
-
- audio_data = np.array(test_wav) / (2 ** 15)
-
- full_audio_data = np.tile(audio_data, 9)
-
- mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
- frame_len_samples, use_htk, n_FFT)
- mfcc_inst = preprocess.MFCC(mfcc_params)
- preprocessor = preprocess.Preprocessor(mfcc_inst, mode_input_size, stride)
-
- input_tensor = preprocessor.extract_features(full_audio_data)
-
- assert np.isclose(input_tensor[0][0], -3.4660944830426454)
- assert np.isclose(input_tensor[0][1], 0.3587718932127629)
- assert np.isclose(input_tensor[0][2], 0.3480551325669172)
- assert np.isclose(input_tensor[0][3], 0.2976191917228921)
- assert np.isclose(input_tensor[0][4], 0.3493037340849936)
- assert np.isclose(input_tensor[0][5], 0.2408643285767937)
- assert np.isclose(input_tensor[0][6], 0.2939659585037282)
- assert np.isclose(input_tensor[0][7], 0.2144552669573928)
- assert np.isclose(input_tensor[0][8], 0.302239565899944)
- assert np.isclose(input_tensor[0][9], 0.3187368787077345)
- assert np.isclose(input_tensor[0][10], 0.3019401051295793)
- assert np.isclose(input_tensor[0][11], 0.20449412797602678)
-
- assert np.isclose(input_tensor[0][38], -0.18751440767749533)
-
-
-def test_sliding_window_for_wav_2_letter_sized_input():
- samp_freq = 16000
- frame_len_ms = 32
- frame_len_samples = samp_freq * frame_len_ms * 0.001
- num_mfcc_feats = 13
- mode_input_size = 296
- stride = 160
- num_fbank_bins = 128
- mel_lo_freq = 0
- mil_hi_freq = 8000
- use_htk = False
- n_FFT = 512
-
- audio_data = np.zeros(47712, dtype=int)
-
- mfcc_params = preprocess.MFCCParams(samp_freq, num_fbank_bins, mel_lo_freq, mil_hi_freq, num_mfcc_feats,
- frame_len_samples, use_htk, n_FFT)
-
- mfcc_inst = preprocess.MFCC(mfcc_params)
- preprocessor = preprocess.Preprocessor(mfcc_inst, mode_input_size, stride)
-
- input_tensor = preprocessor.extract_features(audio_data)
-
- assert len(input_tensor[0]) == 39
- assert len(input_tensor) == 296
diff --git a/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy b/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy
deleted file mode 100644
index a6f9ec0c70..0000000000
--- a/python/pyarmnn/examples/speech_recognition/tests/testdata/inf_out.npy
+++ /dev/null
Binary files differ
diff --git a/python/pyarmnn/examples/speech_recognition/tests/testdata/inference_output.npy b/python/pyarmnn/examples/speech_recognition/tests/testdata/inference_output.npy
new file mode 100644
index 0000000000..88c42e0b70
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/tests/testdata/inference_output.npy
Binary files differ
diff --git a/python/pyarmnn/examples/speech_recognition/tests/testdata/quick_brown_fox_16000khz.wav b/python/pyarmnn/examples/speech_recognition/tests/testdata/quick_brown_fox_16000khz.wav
deleted file mode 100644
index 761c36062e..0000000000
--- a/python/pyarmnn/examples/speech_recognition/tests/testdata/quick_brown_fox_16000khz.wav
+++ /dev/null
Binary files differ
diff --git a/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt b/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt
deleted file mode 100644
index d7485b7da2..0000000000
--- a/python/pyarmnn/examples/speech_recognition/tests/testdata/wav2letter_labels.txt
+++ /dev/null
@@ -1,29 +0,0 @@
-a
-b
-c
-d
-e
-f
-g
-h
-i
-j
-k
-l
-m
-n
-o
-p
-q
-r
-s
-t
-u
-v
-w
-x
-y
-z
-'
-
-$ \ No newline at end of file
diff --git a/python/pyarmnn/examples/speech_recognition/wav2letter_mfcc.py b/python/pyarmnn/examples/speech_recognition/wav2letter_mfcc.py
new file mode 100644
index 0000000000..1cac24d588
--- /dev/null
+++ b/python/pyarmnn/examples/speech_recognition/wav2letter_mfcc.py
@@ -0,0 +1,91 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import numpy as np
+import os
+import sys
+
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
+from mfcc import MFCC, AudioPreprocessor
+
+
+class Wav2LetterMFCC(MFCC):
+ """Extends base MFCC class to provide Wav2Letter-specific MFCC requirements."""
+
+ def __init__(self, mfcc_params):
+ super().__init__(mfcc_params)
+
+ def spectrum_calc(self, audio_data):
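+        """Applies a Hann window to the audio frame and returns its power spectrum from an FFT of length n_fft."""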
+ return np.abs(np.fft.rfft(np.hanning(self.mfcc_params.frame_len + 1)[0:self.mfcc_params.frame_len] * audio_data,
+ self.mfcc_params.n_fft)) ** 2
+
+ def log_mel(self, mel_energy):
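+        """Converts Mel energies to a log (dB) scale and clamps them to an 80 dB dynamic range."""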
+ mel_energy += 1e-10
+ log_mel_energy = 10.0 * np.log10(mel_energy)
+ top_db = 80.0
+ return np.maximum(log_mel_energy, log_mel_energy.max() - top_db)
+
+ def create_dct_matrix(self, num_fbank_bins, num_mfcc_feats):
+ """
+ Creates the Discrete Cosine Transform matrix to be used in the compute function.
+
+ Args:
+ num_fbank_bins: The number of filter bank bins
+ num_mfcc_feats: the number of MFCC features
+
+ Returns:
+ the DCT matrix
+ """
+ dct_m = np.zeros(num_fbank_bins * num_mfcc_feats)
+ for k in range(num_mfcc_feats):
+ for n in range(num_fbank_bins):
+ if k == 0:
+ dct_m[(k * num_fbank_bins) + n] = 2 * np.sqrt(1 / (4 * num_fbank_bins)) * np.cos(
+ (np.pi / num_fbank_bins) * (n + 0.5) * k)
+ else:
+ dct_m[(k * num_fbank_bins) + n] = 2 * np.sqrt(1 / (2 * num_fbank_bins)) * np.cos(
+ (np.pi / num_fbank_bins) * (n + 0.5) * k)
+
+ dct_m = np.reshape(dct_m, [self.mfcc_params.num_mfcc_feats, self.mfcc_params.num_fbank_bins])
+ return dct_m
+
+ def mel_norm(self, weight, right_mel, left_mel):
+ """Over-riding parent class with ASR specific weight normalisation."""
+ enorm = 2.0 / (self.inv_mel_scale(right_mel, False) - self.inv_mel_scale(left_mel, False))
+ return weight * enorm
+
+
+class W2LAudioPreprocessor(AudioPreprocessor):
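+    """Extends the base AudioPreprocessor with Wav2Letter-specific first and second MFCC derivative features."""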
+
+ def __init__(self, mfcc, model_input_size, stride):
+ self.model_input_size = model_input_size
+ self.stride = stride
+
+        super().__init__(mfcc, model_input_size, stride)
+ # Savitzky - Golay differential filters
+ self.savgol_order1_coeffs = np.array([6.66666667e-02, 5.00000000e-02, 3.33333333e-02,
+ 1.66666667e-02, -3.46944695e-18, -1.66666667e-02,
+ -3.33333333e-02, -5.00000000e-02, -6.66666667e-02])
+
+ self.savgol_order2_coeffs = np.array([0.06060606, 0.01515152, -0.01731602,
+ -0.03679654, -0.04329004, -0.03679654,
+ -0.01731602, 0.01515152, 0.06060606])
+ self._mfcc_calc = mfcc
+
+ def mfcc_delta_calc(self, features):
+ """Over-riding parent class with ASR specific MFCC derivative features"""
+ mfcc_delta_np = np.zeros_like(features)
+ mfcc_delta2_np = np.zeros_like(features)
+
+ for i in range(features.shape[1]):
+ idelta = np.convolve(features[:, i], self.savgol_order1_coeffs, 'same')
+ mfcc_delta_np[:, i] = idelta
+ ideltadelta = np.convolve(features[:, i], self.savgol_order2_coeffs, 'same')
+ mfcc_delta2_np[:, i] = ideltadelta
+
+ features = np.concatenate((self._normalize(features), self._normalize(mfcc_delta_np),
+ self._normalize(mfcc_delta2_np)), axis=1)
+
+ return features
diff --git a/python/pyarmnn/examples/common/tests/conftest.py b/python/pyarmnn/examples/tests/conftest.py
index 5e027a0125..b7fa73b852 100644
--- a/python/pyarmnn/examples/common/tests/conftest.py
+++ b/python/pyarmnn/examples/tests/conftest.py
@@ -1,4 +1,4 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
import os
@@ -6,12 +6,13 @@ import ntpath
import urllib.request
import zipfile
-
import pytest
script_dir = os.path.dirname(__file__)
+
+
@pytest.fixture(scope="session")
-def test_data_folder(request):
+def test_data_folder():
"""
This fixture returns path to folder with shared test resources among all tests
"""
@@ -19,11 +20,12 @@ def test_data_folder(request):
data_dir = os.path.join(script_dir, "testdata")
if not os.path.exists(data_dir):
os.mkdir(data_dir)
-
files_to_download = ["https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/messi5.jpg",
"https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/basketball1.png",
"https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/Megamind.avi",
- "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"
+ "https://github.com/ARM-software/ML-zoo/raw/master/models/object_detection/ssd_mobilenet_v1/tflite_uint8/ssd_mobilenet_v1.tflite",
+ "https://git.mlplatform.org/ml/ethos-u/ml-embedded-evaluation-kit.git/plain/resources/kws/samples/yes.wav",
+ "https://raw.githubusercontent.com/Azure-Samples/cognitive-services-speech-sdk/master/sampledata/audiofiles/myVoiceIsMyPassportVerifyMe04.wav"
]
for file in files_to_download:
@@ -33,8 +35,5 @@ def test_data_folder(request):
print("\nDownloading test file: " + file_path + "\n")
urllib.request.urlretrieve(file, file_path)
- # Any unzipping needed, and moving around of files
- with zipfile.ZipFile(os.path.join(data_dir, "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"), 'r') as zip_ref:
- zip_ref.extractall(data_dir)
return data_dir
diff --git a/python/pyarmnn/examples/tests/context.py b/python/pyarmnn/examples/tests/context.py
new file mode 100644
index 0000000000..a678f94178
--- /dev/null
+++ b/python/pyarmnn/examples/tests/context.py
@@ -0,0 +1,22 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import sys
+import numpy as np
+
+script_dir = os.path.dirname(__file__)
+sys.path.insert(0, os.path.join(script_dir, '..'))
+
+import common.cv_utils as cv_utils
+import common.network_executor as network_executor
+import common.utils as utils
+import common.audio_capture as audio_capture
+import common.mfcc as mfcc
+
+import speech_recognition.wav2letter_mfcc as wav2letter_mfcc
+import speech_recognition.audio_utils as audio_utils
+
+
+
+
diff --git a/python/pyarmnn/examples/common/tests/test_utils.py b/python/pyarmnn/examples/tests/test_common_utils.py
index 28d68ea235..28d68ea235 100644
--- a/python/pyarmnn/examples/common/tests/test_utils.py
+++ b/python/pyarmnn/examples/tests/test_common_utils.py
diff --git a/python/pyarmnn/examples/tests/test_mfcc.py b/python/pyarmnn/examples/tests/test_mfcc.py
new file mode 100644
index 0000000000..2e806389e2
--- /dev/null
+++ b/python/pyarmnn/examples/tests/test_mfcc.py
@@ -0,0 +1,247 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import numpy as np
+import pytest
+import collections
+
+from context import mfcc
+from context import wav2letter_mfcc
+from context import audio_capture
+
+# Elements relevant to MFCC filter bank & feature extraction
+MFCC_TEST_PARAMS = collections.namedtuple('mfcc_test_params',
+ ['algo_params', 'mfcc_constructor', 'audio_proc_constructor'])
+
+
+def kws_test_params():
+ kws_algo_params = mfcc.MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
+ num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
+ return MFCC_TEST_PARAMS(kws_algo_params, mfcc.MFCC, mfcc.AudioPreprocessor)
+
+
+def asr_test_params():
+ asr_algo_params = mfcc.MFCCParams(sampling_freq=16000, num_fbank_bins=128, mel_lo_freq=0, mel_hi_freq=8000,
+ num_mfcc_feats=13, frame_len=512, use_htk_method=False, n_fft=512)
+ return MFCC_TEST_PARAMS(asr_algo_params, wav2letter_mfcc.Wav2LetterMFCC, wav2letter_mfcc.W2LAudioPreprocessor)
+
+
+def kws_cap_params():
+ return audio_capture.AudioCaptureParams(dtype=np.float32, overlap=0, min_samples=16000, sampling_freq=16000,
+ mono=True)
+
+
+def asr_cap_params():
+ return audio_capture.AudioCaptureParams(dtype=np.float32, overlap=31712, min_samples=47712,
+ sampling_freq=16000, mono=True)
+
+
+@pytest.fixture()
+def audio_data(test_data_folder, file, audio_cap_params):
+ audio_file = os.path.join(test_data_folder, file)
+ capture = audio_capture.capture_audio(audio_file, audio_cap_params)
+ yield next(capture)
+
+
+@pytest.mark.parametrize("file", ["yes.wav", "myVoiceIsMyPassportVerifyMe04.wav"])
+@pytest.mark.parametrize("audio_cap_params", [kws_cap_params(), asr_cap_params()])
+def test_audio_file(audio_data, test_data_folder, file, audio_cap_params):
+ assert audio_data.shape == (audio_cap_params.min_samples,)
+ assert audio_data.dtype == audio_cap_params.dtype
+
+
+@pytest.mark.parametrize("mfcc_test_params, test_out", [(kws_test_params(), 25.470010570730597),
+ (asr_test_params(), 0.24)])
+def test_mel_scale_function(mfcc_test_params, test_out):
+ mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
+ mel = mfcc_inst.mel_scale(16, mfcc_test_params.algo_params.use_htk_method)
+ assert np.isclose(mel, test_out)
+
+
+@pytest.mark.parametrize("mfcc_test_params, test_out", [(kws_test_params(), 10.008767240008943),
+ (asr_test_params(), 1071.170287494467)])
+def test_inverse_mel_scale_function(mfcc_test_params, test_out):
+ mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
+ mel = mfcc_inst.inv_mel_scale(16, mfcc_test_params.algo_params.use_htk_method)
+ assert np.isclose(mel, test_out)
+
+
+mel_filter_test_data_kws = {0: [0.33883214, 0.80088392, 0.74663128, 0.30332531],
+ 1: [0.25336872, 0.69667469, 0.86883317, 0.44281119, 0.02493546],
+ 2: [0.13116683, 0.55718881, 0.97506454, 0.61490026, 0.21241678],
+ 5: [0.32725038, 0.69579596, 0.9417706, 0.58524989, 0.23445207],
+ -1: [0.02433275, 0.10371618, 0.1828123, 0.26162319, 0.34015089, 0.41839743,
+ 0.49636481, 0.57405503, 0.65147004, 0.72861179, 0.8054822, 0.88208318,
+ 0.95841659, 0.96551568, 0.88971181, 0.81416996, 0.73888833, 0.66386514,
+ 0.58909861, 0.514587, 0.44032856, 0.3663216, 0.29256441, 0.21905531,
+ 0.14579264, 0.07277474]}
+
+mel_filter_test_data_asr = {0: [0.02837754],
+ 1: [0.01438901, 0.01398853],
+ 2: [0.02877802],
+ 5: [0.01478948, 0.01358806],
+ -1: [4.82151203e-05, 9.48791110e-04, 1.84569875e-03, 2.73896782e-03,
+ 3.62862771e-03, 4.51470746e-03, 5.22215439e-03, 4.34314914e-03,
+ 3.46763895e-03, 2.59559614e-03, 1.72699334e-03, 8.61803536e-04]}
+
+
+@pytest.mark.parametrize("mfcc_test_params, test_out",
+ [(kws_test_params(), mel_filter_test_data_kws),
+ (asr_test_params(), mel_filter_test_data_asr)])
+def test_create_mel_filter_bank(mfcc_test_params, test_out):
+ mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
+ mel_filter_bank = mfcc_inst.create_mel_filter_bank()
+ assert len(mel_filter_bank) == mfcc_test_params.algo_params.num_fbank_bins
+ for indx, data in test_out.items():
+ assert np.allclose(mel_filter_bank[indx], data)
+
+
+mfcc_test_data_kws = (-22.671347398982626, -0.6161543999707211, 2.072326974167832,
+ 0.5813741475362223, 1.0165529747334272, 0.8581560719988703,
+ 0.4603911069624896, 0.03392820944377398, 1.1651093266902361,
+ 0.007200025869960908)
+
+mfcc_test_data_asr = (-735.46345398, 69.50331943, 16.39159347, 22.74874819, 24.84782893,
+ 10.67559303, 12.82828618, -3.51084271, 4.66633677, 10.20079095, 11.34782948, 3.90499354,
+ 9.32322384)
+
+
+@pytest.mark.parametrize("mfcc_test_params, test_out, file, audio_cap_params",
+ [(kws_test_params(), mfcc_test_data_kws, "yes.wav", kws_cap_params()),
+ (asr_test_params(), mfcc_test_data_asr, "myVoiceIsMyPassportVerifyMe04.wav",
+ asr_cap_params())])
+def test_mfcc_compute_first_frame(audio_data, mfcc_test_params, test_out, file, audio_cap_params):
+ audio_data = np.array(audio_data)[0:mfcc_test_params.algo_params.frame_len]
+ mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
+ mfcc_feats = mfcc_inst.mfcc_compute(audio_data)
+ assert np.allclose((mfcc_feats[0:mfcc_test_params.algo_params.num_mfcc_feats]), test_out)
+
+
+extract_test_data_kws = {0: [-2.2671347e+01, -6.1615437e-01, 2.0723269e+00, 5.8137417e-01,
+ 1.0165529e+00, 8.5815609e-01, 4.6039110e-01, 3.3928208e-02,
+ 1.1651093e+00, 7.2000260e-03],
+ 1: [-23.488806, -1.1687667, 3.0548365, 1.5129884, 1.4142203,
+ 0.6869772, 1.1875846, 0.5743369, 1.202258, -0.12133602],
+ 2: [-23.909292, -1.5186096, 1.8721082, 0.7378916, 0.44974303,
+ 0.17609395, 0.5183161, 0.37109664, 0.14186797, 0.58400506],
+ -1: [-23.752186, -0.1796912, 1.9514247, 0.32554424, 1.8425112,
+ 0.8763608, 0.78326845, 0.27808753, 0.73788685, 0.30338883]}
+
+extract_test_data_asr = {0: [-4.98830318e+00, 6.86444461e-01, 3.12024504e-01, 3.56840312e-01,
+ 3.71638149e-01, 2.71728605e-01, 2.86904365e-01, 1.71718955e-01,
+ 2.29365349e-01, 2.68381387e-01, 2.76467651e-01, 2.23998129e-01,
+ 2.62194842e-01, -1.48247385e+01, 1.21875501e+00, 4.20235842e-01,
+ 5.39400637e-01, 6.09882712e-01, 1.68513224e-01, 3.75330061e-01,
+ 8.57576132e-02, 1.92831963e-01, 1.41814977e-01, 1.57615796e-01,
+ 7.19076321e-02, 1.98729336e-01, 3.92199278e+00, -5.76856315e-01,
+ 1.17938723e-02, -9.25096497e-02, -3.59488949e-02, 1.13284402e-03,
+ 1.51282102e-01, 1.13404110e-01, -8.69824737e-02, -1.48449212e-01,
+ -1.24230251e-01, -1.90728232e-01, -5.37525006e-02],
+ 1: [-4.96694946e+00, 6.69411421e-01, 2.86189795e-01, 3.65071595e-01,
+ 3.92671198e-01, 2.44258150e-01, 2.52177566e-01, 2.16024980e-01,
+ 2.79812217e-01, 2.79687315e-01, 2.95228422e-01, 2.83991724e-01,
+ 2.46358261e-01, -1.33618221e+01, 1.08920455e+00, 3.88707787e-01,
+ 5.05674303e-01, 6.08285785e-01, 1.68113053e-01, 3.54529470e-01,
+ 6.68609440e-02, 1.52882755e-01, 6.89579248e-02, 1.18375972e-01,
+ 5.86742274e-02, 1.15678251e-01, 1.07892036e+01, -1.07193100e+00,
+ -2.18140319e-01, -3.35950345e-01, -2.57241666e-01, -5.54431602e-02,
+ -8.38544443e-02, -5.79114584e-03, -2.23973781e-01, -2.91451365e-01,
+ -2.11069033e-01, -1.90297231e-01, -2.76504964e-01],
+ 2: [-4.98664522e+00, 6.54802263e-01, 3.70355755e-01, 4.06837821e-01,
+ 4.05175537e-01, 2.29149669e-01, 2.83312678e-01, 2.17573136e-01,
+ 3.07824671e-01, 2.48388007e-01, 2.25399241e-01, 2.52003014e-01,
+ 2.83968121e-01, -1.05043650e+01, 7.91533887e-01, 3.11546475e-01,
+ 4.36079264e-01, 5.93271911e-01, 2.02480286e-01, 3.24254721e-01,
+ 6.29674867e-02, 9.67641100e-02, -1.62826646e-02, 5.47595806e-02,
+ 2.90475693e-02, 2.62522381e-02, 1.38787737e+01, -1.32597208e+00,
+ -3.73900205e-01, -4.38065380e-01, -3.05983245e-01, 1.14390980e-02,
+ -2.10821658e-01, -6.22789040e-02, -2.88273603e-01, -3.29794526e-01,
+ -2.43764088e-01, -1.70954674e-01, -3.65193188e-01],
+ -1: [-2.1894817, 1.583355, -0.45024827, 0.11657667, 0.08940444, 0.09041209,
+ 0.2003613, 0.11800499, 0.18838657, 0.29271516, 0.22758003, 0.10634928,
+ -0.04019014, 7.203311, -2.414309, 0.28750962, -0.24222863, 0.04680864,
+ -0.12129474, 0.18059334, 0.06250379, 0.11363743, -0.2561094, -0.08132717,
+ -0.08500769, 0.18916495, 1.3529671, -3.7919693, 1.937804, 0.6845761,
+ 0.15381853, 0.41106734, -0.28207013, 0.2195526, 0.06716935, -0.02886542,
+ -0.22860551, 0.24788341, 0.63940096]}
+
+
+@pytest.mark.parametrize("mfcc_test_params, model_input_size, stride, min_samples, file, audio_cap_params, test_out",
+ [(kws_test_params(), 49, 320, 16000, "yes.wav", kws_cap_params(),
+ extract_test_data_kws),
+ (asr_test_params(), 296, 160, 47712, "myVoiceIsMyPassportVerifyMe04.wav", asr_cap_params(),
+ extract_test_data_asr)])
+def test_feat_extraction_full_sized_input(audio_data,
+ mfcc_test_params,
+ model_input_size,
+ stride,
+ min_samples, file, audio_cap_params,
+ test_out):
+ """
+ Test out values were gathered by printing the mfcc features collected during the first full inference
+ on the test wav files. Note the extract_features() function simply calls the mfcc_compute() from previous
+ test but feeds in enough samples for an inference rather than a single frame. It also computes the 1st & 2nd
+ derivative features hence the shape (13*3 = 39).
+ Specific model_input_size and stride parameters are also required as additional arguments.
+ """
+ audio_data = np.array(audio_data)
+ # Pad with zeros to ensure min_samples for inference
+ audio_data.resize(min_samples)
+ mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
+ preprocessor = mfcc_test_params.audio_proc_constructor(mfcc_inst, model_input_size, stride)
+ # extract_features passes the audio data to mfcc_compute frame by frame and concatenates results
+ input_tensor = preprocessor.extract_features(audio_data)
+ assert len(input_tensor) == model_input_size
+ for indx, data in test_out.items():
+ assert np.allclose(input_tensor[indx], data)
+
+
+# Expected contents of input tensors for inference on a silent wav file
+extract_features_zeros_kws = {0: [-2.05949466e+02, -4.88498131e-15, 8.15428020e-15, -5.77315973e-15,
+ 7.03142511e-15, -1.11022302e-14, 2.18015108e-14, -1.77635684e-15,
+ 1.06581410e-14, 2.75335310e-14],
+ -1: [-2.05949466e+02, -4.88498131e-15, 8.15428020e-15, -5.77315973e-15,
+ 7.03142511e-15, -1.11022302e-14, 2.18015108e-14, -1.77635684e-15,
+ 1.06581410e-14, 2.75335310e-14]}
+
+extract_features_zeros_asr = {
+ 0: [-3.46410162e+00, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
+ 2.88675135e-01, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
+ 2.88675135e-01, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
+ 2.88675135e-01, 2.79662980e+01, 1.75638694e-15, -9.41313626e-16,
+ 9.66012817e-16, -1.23221521e-15, 1.75638694e-15, -1.59035349e-15,
+ 2.41503204e-15, -1.64798493e-15, 4.39096735e-16, -4.95356004e-16,
+ -2.19548368e-16, -3.55668355e-15, 8.19843971e+00, -4.28340672e-02,
+ -4.28340672e-02, -4.28340672e-02, -4.28340672e-02, -4.28340672e-02,
+ -4.28340672e-02, -4.28340672e-02, -4.28340672e-02, -4.28340672e-02,
+ -4.28340672e-02, -4.28340672e-02, -4.28340672e-02],
+ - 1: [-3.46410162e+00, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
+ 2.88675135e-01, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
+ 2.88675135e-01, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
+ 2.88675135e-01, 2.79662980e+01, 1.75638694e-15, -9.41313626e-16,
+ 9.66012817e-16, -1.23221521e-15, 1.75638694e-15, -1.59035349e-15,
+ 2.41503204e-15, -1.64798493e-15, 4.39096735e-16, -4.95356004e-16,
+ -2.19548368e-16, -3.55668355e-15, 8.19843971e+00, -4.28340672e-02,
+ -4.28340672e-02, -4.28340672e-02, -4.28340672e-02, -4.28340672e-02,
+ -4.28340672e-02, -4.28340672e-02, -4.28340672e-02, -4.28340672e-02,
+ -4.28340672e-02, -4.28340672e-02, -4.28340672e-02]}
+
+
+@pytest.mark.parametrize("mfcc_test_params,model_input_size, stride, min_samples, test_out",
+ [(kws_test_params(), 49, 320, 16000, extract_features_zeros_kws),
+ (asr_test_params(), 296, 160, 47712, extract_features_zeros_asr)])
+def test_feat_extraction_full_sized_input_zeros(mfcc_test_params, model_input_size, stride, min_samples, test_out):
+ audio_data = np.zeros(min_samples).astype(np.float32)
+ mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
+
+ preprocessor = mfcc_test_params.audio_proc_constructor(mfcc_inst, model_input_size,
+ stride)
+ input_tensor = preprocessor.extract_features(audio_data)
+ assert len(input_tensor) == model_input_size
+ for indx, data in test_out.items():
+        # The features around index 14 of the extraction vector differ minutely during inference
+        # on a silent wav file compared to an array of zeros, so indices 13-14 are skipped here
+        # (the alternative workaround is a large tolerance argument, e.g. atol=10)
+ assert np.allclose(input_tensor[indx][0:13], data[0:13])
+ assert np.allclose(input_tensor[indx][15:], data[15:])
diff --git a/python/pyarmnn/examples/common/tests/test_network_executor.py b/python/pyarmnn/examples/tests/test_network_executor.py
index e27b382078..c124b11382 100644
--- a/python/pyarmnn/examples/common/tests/test_network_executor.py
+++ b/python/pyarmnn/examples/tests/test_network_executor.py
@@ -10,12 +10,12 @@ from context import cv_utils
def test_execute_network(test_data_folder):
- model_path = os.path.join(test_data_folder, "detect.tflite")
+ model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite")
backends = ["CpuAcc", "CpuRef"]
executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
img = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
- input_tensors = cv_utils.preprocess(img, executor.input_binding_info)
+ input_tensors = cv_utils.preprocess(img, executor.input_binding_info, True)
output_result = executor.run(input_tensors)
diff --git a/python/pyarmnn/examples/tests/testdata/labelmap.txt b/python/pyarmnn/examples/tests/testdata/labelmap.txt
new file mode 100644
index 0000000000..444850dc8b
--- /dev/null
+++ b/python/pyarmnn/examples/tests/testdata/labelmap.txt
@@ -0,0 +1,9 @@
+person
+motorcycle
+airplane
+bicycle
+train
+boat
+truck
+bus
+