# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""

import sys
import os
from argparse import ArgumentParser

import numpy as np
import sounddevice as sd

script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))

from network_executor import ArmnnNetworkExecutor
from utils import prepare_input_tensors, dequantize_output
from mfcc import AudioPreprocessor, MFCC, MFCCParams
from audio_utils import decode, display_text
from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio

# Model Specific Labels
labels = {0: 'silence',
          1: 'unknown',
          2: 'yes',
          3: 'no',
          4: 'up',
          5: 'down',
          6: 'left',
          7: 'right',
          8: 'on',
          9: 'off',
          10: 'stop',
          11: 'go'}


def parse_args():
    parser = ArgumentParser(description="KWS with PyArmNN")
    parser.add_argument(
        "--audio_file_path",
        required=False,
        type=str,
        help="Path to the audio file to perform KWS",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=0,
        help="""Duration for recording audio in seconds. Values <= 0 result in infinite recording.
        Defaults to infinite.""",
    )
    parser.add_argument(
        "--model_file_path",
        required=True,
        type=str,
        help="Path to KWS model to use",
    )
    parser.add_argument(
        "--preferred_backends",
        type=str,
        nargs="+",
        default=["CpuAcc", "CpuRef"],
        help="""List of backends in order of preference for optimizing subgraphs,
        falling back to the next backend in the list on unsupported layers.
        Defaults to [CpuAcc, CpuRef]""",
    )
    return parser.parse_args()


def recognise_speech(audio_data, network, preprocessor, threshold):
    # Prepare the input Tensors
    input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)

    # Run inference
    output_result = network.run(input_tensors)

    dequantized_result = []
    for index, ofm in enumerate(output_result):
        dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index]))

    # Decode the text and display result if above threshold
    decoded_result = decode(dequantized_result, labels)

    if decoded_result[1] > threshold:
        display_text(decoded_result)


def main(args):
    # Read command line args and invoke mic streaming if no file path supplied
    audio_file = args.audio_file_path
    if args.audio_file_path:
        streaming_enabled = False
    else:
        streaming_enabled = True

    # Create the ArmNN inference runner
    network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)

    # Specify model specific audio data requirements
    # Overlap value specifies the number of samples to rewind between each data window
    audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000,
                                              sampling_freq=16000, mono=True)

    # Create the preprocessor
    mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
                             num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
    mfcc = MFCC(mfcc_params)
    preprocessor = AudioPreprocessor(mfcc, model_input_size=49, stride=320)

    # Set threshold for displaying classification and commence stream or file processing
    threshold = .90

    if streaming_enabled:
        # Initialise audio stream
        record_stream = CaptureAudioStream(audio_capture_params)
        record_stream.set_stream_defaults()
        record_stream.set_recording_duration(args.duration)
        record_stream.countdown()

        with sd.InputStream(callback=record_stream.callback):
            print("Recording audio. Please speak.")
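            # Each loop iteration below pulls the next window of samples from
            # the stream; per AudioCaptureParams above, successive windows are
            # rewound by `overlap` samples, so a keyword spoken across a
            # window boundary is still seen whole by the model.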
Please speak.") while record_stream.is_active: audio_data = record_stream.capture_data() recognise_speech(audio_data, network, preprocessor, threshold) record_stream.is_first_window = False print("\nFinished recording.") # If file path has been supplied read-in and run inference else: print("Processing Audio Frames...") buffer = capture_audio(audio_file, audio_capture_params) for audio_data in buffer: recognise_speech(audio_data, network, preprocessor, threshold) if __name__ == "__main__": args = parse_args() main(args)