author | alexander <alexander.efremov@arm.com> | 2021-07-16 11:30:56 +0100
committer | Jim Flynn <jim.flynn@arm.com> | 2022-02-04 09:55:21 +0000
commit | f42f56870c6201a876f025a423eb5540d7438e83 (patch)
tree | e8e57e371c851cbb9a51a2f3ec35059addd2e93e /python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
parent | 9d74ba6e85a043e9603445e062315f5c4965fbd6 (diff)
download | armnn-f42f56870c6201a876f025a423eb5540d7438e83.tar.gz
MLECO-2079 Adding the python KWS example
Signed-off-by: Eanna O Cathain <eanna.ocathain@arm.com>
Change-Id: Ie1463aaeb5e3cade22df8f560ae99a8e1c4a9c17
Diffstat (limited to 'python/pyarmnn/examples/keyword_spotting/run_audio_classification.py')
-rw-r--r-- | python/pyarmnn/examples/keyword_spotting/run_audio_classification.py | 136
1 file changed, 136 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
new file mode 100644
index 0000000000..6dfa4cc806
--- /dev/null
+++ b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
@@ -0,0 +1,136 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""
+
+import sys
+import os
+from argparse import ArgumentParser
+
+import numpy as np
+import sounddevice as sd
+
+script_dir = os.path.dirname(__file__)
+sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
+
+from network_executor import ArmnnNetworkExecutor
+from utils import prepare_input_tensors, dequantize_output
+from mfcc import AudioPreprocessor, MFCC, MFCCParams
+from audio_utils import decode, display_text
+from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio
+
+# Model Specific Labels
+labels = {0: 'silence',
+          1: 'unknown',
+          2: 'yes',
+          3: 'no',
+          4: 'up',
+          5: 'down',
+          6: 'left',
+          7: 'right',
+          8: 'on',
+          9: 'off',
+          10: 'stop',
+          11: 'go'}
+
+
+def parse_args():
+    parser = ArgumentParser(description="KWS with PyArmNN")
+    parser.add_argument(
+        "--audio_file_path",
+        required=False,
+        type=str,
+        help="Path to the audio file to perform KWS",
+    )
+    parser.add_argument(
+        "--duration",
+        type=int,
+        default=0,
+        help="""Duration for recording audio in seconds. Values <= 0 result in infinite
+        recording. Defaults to infinite.""",
+    )
+    parser.add_argument(
+        "--model_file_path",
+        required=True,
+        type=str,
+        help="Path to KWS model to use",
+    )
+    parser.add_argument(
+        "--preferred_backends",
+        type=str,
+        nargs="+",
+        default=["CpuAcc", "CpuRef"],
+        help="""List of backends in order of preference for optimizing
+        subgraphs, falling back to the next backend in the list on unsupported
+        layers. Defaults to [CpuAcc, CpuRef]""",
+    )
+    return parser.parse_args()
+
+
+def recognise_speech(audio_data, network, preprocessor, threshold):
+    # Prepare the input Tensors
+    input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+    # Run inference
+    output_result = network.run(input_tensors)
+
+    dequantized_result = []
+    for index, ofm in enumerate(output_result):
+        dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index]))
+
+    # Decode the text and display result if above threshold
+    decoded_result = decode(dequantized_result, labels)
+
+    if decoded_result[1] > threshold:
+        display_text(decoded_result)
+
+
+def main(args):
+    # Read command line args and invoke mic streaming if no file path supplied
+    audio_file = args.audio_file_path
+    if args.audio_file_path:
+        streaming_enabled = False
+    else:
+        streaming_enabled = True
+    # Create the ArmNN inference runner
+    network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+
+    # Specify model specific audio data requirements
+    # Overlap value specifies the number of samples to rewind between each data window
+    audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000, sampling_freq=16000,
+                                              mono=True)
+
+    # Create the preprocessor
+    mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
+                             num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
+    mfcc = MFCC(mfcc_params)
+    preprocessor = AudioPreprocessor(mfcc, model_input_size=49, stride=320)
+
+    # Set threshold for displaying classification and commence stream or file processing
+    threshold = .90
+    if streaming_enabled:
+        # Initialise audio stream
+        record_stream = CaptureAudioStream(audio_capture_params)
+        record_stream.set_stream_defaults()
+        record_stream.set_recording_duration(args.duration)
+        record_stream.countdown()
+
+        with sd.InputStream(callback=record_stream.callback):
+            print("Recording audio. Please speak.")
+            while record_stream.is_active:
+
+                audio_data = record_stream.capture_data()
+                recognise_speech(audio_data, network, preprocessor, threshold)
+                record_stream.is_first_window = False
+            print("\nFinished recording.")
+
+    # If file path has been supplied read-in and run inference
+    else:
+        print("Processing Audio Frames...")
+        buffer = capture_audio(audio_file, audio_capture_params)
+        for audio_data in buffer:
+            recognise_speech(audio_data, network, preprocessor, threshold)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
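
A note for readers of this patch: the fixed numbers in the preprocessing set-up are mutually consistent. A 16000-sample capture window (1 s of 16 kHz mono audio) split into 640-sample MFCC frames with a 320-sample hop yields exactly the 49 feature frames the model expects. A minimal sketch of that arithmetic in plain Python, with no PyArmNN dependency (the model filename in the example invocation is hypothetical):

# Example invocation of the new script (model filename is hypothetical):
#   python run_audio_classification.py --model_file_path kws_model.tflite --duration 10

min_samples = 16000  # samples per inference window: 1 s of 16 kHz mono audio
frame_len = 640      # samples per MFCC frame (40 ms), as in MFCCParams above
stride = 320         # hop between successive MFCC frames (20 ms)

num_frames = (min_samples - frame_len) // stride + 1
print(num_frames)    # 49 -> matches model_input_size passed to AudioPreprocessor

# On the capture side, overlap=2000 rewinds 2000 samples between windows,
# so consecutive 16000-sample windows advance by 14000 samples (~0.875 s).
overlap = 2000
print(min_samples - overlap)  # 14000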