summaryrefslogtreecommitdiff
path: root/scripts/py/gen_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/py/gen_utils.py')
-rw-r--r--scripts/py/gen_utils.py194
1 files changed, 141 insertions, 53 deletions
diff --git a/scripts/py/gen_utils.py b/scripts/py/gen_utils.py
index ee33705..6bb4760 100644
--- a/scripts/py/gen_utils.py
+++ b/scripts/py/gen_utils.py
@@ -1,6 +1,6 @@
#!env/bin/python3
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,21 +14,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import soundfile as sf
-import resampy
+"""
+Utility functions for .cc + .hpp file generation
+"""
+import argparse
+import datetime
+from dataclasses import dataclass
+from pathlib import Path
+
+import jinja2
import numpy as np
+import resampy
+import soundfile as sf
-class AudioUtils:
+@dataclass
+class AudioSample:
+ """
+ Represents an audio sample with its sample rate
+ """
+ data: np.ndarray
+ sample_rate: int
+
+
+class GenUtils:
+ """
+ Class with utility functions for audio and other .cc + .hpp file generation
+ """
+
@staticmethod
def res_data_type(res_type_value):
"""
Returns the input string if is one of the valid resample type
"""
- import argparse
- if res_type_value not in AudioUtils.res_type_list():
- raise argparse.ArgumentTypeError(f"{res_type_value} not valid. Supported only {AudioUtils.res_type_list()}")
+ if res_type_value not in GenUtils.res_type_list():
+ raise argparse.ArgumentTypeError(
+ f"{res_type_value} not valid. Supported only {GenUtils.res_type_list()}"
+ )
return res_type_value
@staticmethod
@@ -39,27 +61,18 @@ class AudioUtils:
return ['kaiser_best', 'kaiser_fast']
@staticmethod
- def load_resample_audio_clip(path, target_sr=16000, mono=True, offset=0.0, duration=0, res_type='kaiser_best',
- min_len=16000):
+ def read_audio_file(
+ path,
+ offset,
+ duration
+ ) -> AudioSample:
"""
- Load and resample an audio clip with the given desired specs.
+ Reads an audio file to an array
- Parameters:
- ----------
- path (string): Path to the input audio clip.
- target_sr (int, optional): Target sampling rate. Positive number are considered valid,
- if zero or negative the native sampling rate of the file will be preserved. Default is 16000.
- mono (bool, optional): Specify if the audio file needs to be converted to mono. Default is True.
- offset (float, optional): Target sampling rate. Default is 0.0.
- duration (int, optional): Target duration. Positive number are considered valid,
- if zero or negative the duration of the file will be preserved. Default is 0.
- res_type (int, optional): Resample type to use, Default is 'kaiser_best'.
- min_len (int, optional): Minimun lenght of the output audio time series. Default is 16000.
-
- Returns:
- ----------
- y (np.ndarray): Output audio time series of shape shape=(n,) or (2, n).
- sr (int): A scalar number > 0 that represent the sampling rate of `y`
+ @param path: Path to audio file
+ @param offset: Offset to read from
+ @param duration: Duration to read
+ @return: The audio data and the sample rate
"""
try:
with sf.SoundFile(path) as audio_file:
@@ -76,40 +89,115 @@ class AudioUtils:
# Load the target number of frames
y = audio_file.read(frames=num_frame_duration, dtype=np.float32, always_2d=False).T
-
- except:
+ except OSError as err:
print(f"Failed to open {path} as an audio.")
+ raise err
+
+ return AudioSample(y, origin_sr)
+
+ @staticmethod
+ def _resample_audio(
+ y,
+ target_sr,
+ origin_sr,
+ res_type
+ ):
+ """
+ Resamples audio to a different sample rate
+
+ @param y: Audio to resample
+ @param target_sr: Target sample rate
+ @param origin_sr: Original sample rate
+ @param res_type: Resample type
+ @return: The resampled audio
+ """
+ ratio = float(target_sr) / origin_sr
+ axis = -1
+ n_samples = int(np.ceil(y.shape[axis] * ratio))
+
+ # Resample using resampy
+ y_rs = resampy.resample(y, origin_sr, target_sr, filter=res_type, axis=axis)
+ n_rs_samples = y_rs.shape[axis]
+
+ # Adjust the size
+ if n_rs_samples > n_samples:
+ slices = [slice(None)] * y_rs.ndim
+ slices[axis] = slice(0, n_samples)
+ y = y_rs[tuple(slices)]
+ elif n_rs_samples < n_samples:
+ lengths = [(0, 0)] * y_rs.ndim
+ lengths[axis] = (0, n_samples - n_rs_samples)
+ y = np.pad(y_rs, lengths, 'constant', constant_values=0)
+
+ return y
+
+ @staticmethod
+ def resample_audio_clip(
+ audio_sample: AudioSample,
+ target_sr=16000,
+ mono=True,
+ res_type='kaiser_best',
+ min_len=16000
+ ) -> AudioSample:
+ """
+ Load and resample an audio clip with the given desired specs.
+
+ Parameters:
+ ----------
+ path (string): Path to the input audio clip.
+ target_sr (int, optional): Target sampling rate. Positive number are considered valid,
+ if zero or negative the native sampling rate of the file
+ will be preserved. Default is 16000.
+ mono (bool, optional): Specify if the audio file needs to be converted to mono.
+ Default is True.
+ offset (float, optional): Target sampling rate. Default is 0.0.
+ duration (int, optional): Target duration. Positive number are considered valid,
+ if zero or negative the duration of the file
+ will be preserved. Default is 0.
+ res_type (int, optional): Resample type to use, Default is 'kaiser_best'.
+ min_len (int, optional): Minimum length of the output audio time series.
+ Default is 16000.
+
+ Returns:
+ ----------
+ y (np.ndarray): Output audio time series of shape=(n,) or (2, n).
+ sample_rate (int): A scalar number > 0 that represent the sampling rate of `y`
+ """
+ y = audio_sample.data.copy()
# Convert to mono if requested and if audio has more than one dimension
- if mono and (y.ndim > 1):
+ if mono and (audio_sample.data.ndim > 1):
y = np.mean(y, axis=0)
- if not (origin_sr == target_sr) and (target_sr > 0):
- ratio = float(target_sr) / origin_sr
- axis = -1
- n_samples = int(np.ceil(y.shape[axis] * ratio))
-
- # Resample using resampy
- y_rs = resampy.resample(y, origin_sr, target_sr, filter=res_type, axis=axis)
- n_rs_samples = y_rs.shape[axis]
-
- # Adjust the size
- if n_rs_samples > n_samples:
- slices = [slice(None)] * y_rs.ndim
- slices[axis] = slice(0, n_samples)
- y = y_rs[tuple(slices)]
- elif n_rs_samples < n_samples:
- lengths = [(0, 0)] * y_rs.ndim
- lengths[axis] = (0, n_samples - n_rs_samples)
- y = np.pad(y_rs, lengths, 'constant', constant_values=(0))
-
- sr = target_sr
+ if not (audio_sample.sample_rate == target_sr) and (target_sr > 0):
+ y = GenUtils._resample_audio(y, target_sr, audio_sample.sample_rate, res_type)
+ sample_rate = target_sr
else:
- sr = origin_sr
+ sample_rate = audio_sample.sample_rate
# Pad if necessary and min lenght is setted (min_len> 0)
if (y.shape[0] < min_len) and (min_len > 0):
sample_to_pad = min_len - y.shape[0]
- y = np.pad(y, (0, sample_to_pad), 'constant', constant_values=(0))
+ y = np.pad(y, (0, sample_to_pad), 'constant', constant_values=0)
+
+ return AudioSample(data=y, sample_rate=sample_rate)
- return y, sr
+ @staticmethod
+ def gen_header(
+ env: jinja2.Environment,
+ header_template_file: str,
+ file_name: str = None
+ ) -> str:
+ """
+ Generate common licence header
+
+ :param env: Jinja2 environment
+ :param header_template_file: Path to the licence header template
+ :param file_name: Optional generating script file name
+ :return: Generated licence header as a string
+ """
+ header_template = env.get_template(header_template_file)
+ return header_template.render(script_name=Path(__file__).name,
+ gen_time=datetime.datetime.now(),
+ file_name=file_name,
+ year=datetime.datetime.now().year)