diff options
Diffstat (limited to 'python/pyarmnn/examples/common/audio_capture.py')
-rw-r--r-- | python/pyarmnn/examples/common/audio_capture.py | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/common/audio_capture.py b/python/pyarmnn/examples/common/audio_capture.py new file mode 100644 index 0000000000..1bd53b4473 --- /dev/null +++ b/python/pyarmnn/examples/common/audio_capture.py @@ -0,0 +1,149 @@ +# Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +# SPDX-License-Identifier: MIT +"""Contains CaptureAudioStream class for capturing chunks of audio data from incoming + stream and generic capture_audio function for capturing from files.""" +import collections +import time +from queue import Queue +from typing import Generator + +import numpy as np +import sounddevice as sd +import soundfile as sf + +AudioCaptureParams = collections.namedtuple('AudioCaptureParams', + ['dtype', 'overlap', 'min_samples', 'sampling_freq', 'mono']) + + +def capture_audio(audio_file_path, params_tuple) -> Generator[np.ndarray, None, None]: + """Creates a generator that yields audio data from a file. Data is padded with + zeros if necessary to make up minimum number of samples. + Args: + audio_file_path: Path to audio file provided by user. + params_tuple: Sampling parameters for model used + Yields: + Blocks of audio data of minimum sample size. + """ + with sf.SoundFile(audio_file_path) as audio_file: + for block in audio_file.blocks( + blocksize=params_tuple.min_samples, + dtype=params_tuple.dtype, + always_2d=True, + fill_value=0, + overlap=params_tuple.overlap + ): + if params_tuple.mono and block.shape[0] > 1: + block = np.mean(block, dtype=block.dtype, axis=1) + yield block + + +class CaptureAudioStream: + + def __init__(self, audio_capture_params): + self.audio_capture_params = audio_capture_params + self.collection = np.zeros(self.audio_capture_params.min_samples + self.audio_capture_params.overlap).astype( + dtype=self.audio_capture_params.dtype) + self.is_active = True + self.is_first_window = True + self.duration = False + self.block_count = 0 + self.current_block = 0 + self.queue = Queue(2) + + def set_stream_defaults(self): + """Discovers input devices on the system and sets default stream parameters.""" + print(sd.query_devices()) + device = input("Select input device by index or name: ") + + try: + sd.default.device = int(device) + except ValueError: + sd.default.device = str(device) + + sd.default.samplerate = self.audio_capture_params.sampling_freq + sd.default.blocksize = self.audio_capture_params.min_samples + sd.default.dtype = self.audio_capture_params.dtype + sd.default.channels = 1 if self.audio_capture_params.mono else 2 + + def set_recording_duration(self, duration): + """Sets a time duration (in integer seconds) for recording audio. Total time duration is + adjusted to a minimum based on the parameters of the model used. Durations less than 1 + result in endless recording. + + Args: + duration (int): User-provided command line argument for time duration of recording. + """ + if duration > 0: + min_duration = int( + np.ceil(self.audio_capture_params.min_samples / self.audio_capture_params.sampling_freq) + ) + if duration < min_duration: + print(f"Minimum duration must be {min_duration} seconds of audio") + print(f"Setting minimum recording duration...") + duration = min_duration + + print(f"Recording duration is {duration} seconds") + self.duration = self.audio_capture_params.sampling_freq * duration + self.block_count, remainder_samples = divmod( + self.duration, self.audio_capture_params.min_samples + ) + + if remainder_samples > 0.5 * self.audio_capture_params.sampling_freq: + self.block_count += 1 + else: + self.duration = False # Record forever + + def countdown(self, delay=3): + """3 second countdown prior to recording audio.""" + print("Beginning recording in...") + for i in range(delay, 0, -1): + print(f"{i}...") + time.sleep(1) + + def update(self): + """If a duration has been set, increments a counter to update the number of blocks of audio + data left to be collected. The stream is deactivated upon reaching the maximum block count + determined by the duration. + """ + if self.duration: + self.current_block += 1 + if self.current_block == self.block_count: + self.is_active = False + + def capture_data(self): + """Gets the next window of audio data by retrieving the newest data from a queue and + shifting the position of the data in the collection. Overlap values of less than `min_samples` are supported. + """ + new_data = self.queue.get() + + if self.is_first_window or self.audio_capture_params.overlap == 0: + self.collection[:self.audio_capture_params.min_samples] = new_data[:] + + elif self.audio_capture_params.overlap < self.audio_capture_params.min_samples: + # + self.collection[0:self.audio_capture_params.overlap] = \ + self.collection[(self.audio_capture_params.min_samples - self.audio_capture_params.overlap): + self.audio_capture_params.min_samples] + + self.collection[self.audio_capture_params.overlap:( + self.audio_capture_params.overlap + self.audio_capture_params.min_samples)] = new_data[:] + else: + raise ValueError( + "Capture Error: Overlap must be less than {}".format(self.audio_capture_params.min_samples)) + audio_data = self.collection[0:self.audio_capture_params.min_samples] + return np.asarray(audio_data).astype(self.audio_capture_params.dtype) + + def callback(self, data, frames, time, status): + """Places audio data from active stream into a queue for processing. + Update counter if recording duration is finite. + """ + + if self.duration: + self.update() + + if self.audio_capture_params.mono: + audio_data = data.copy().flatten() + else: + audio_data = data.copy() + + self.queue.put(audio_data) |