summaryrefslogtreecommitdiff
path: root/scripts/py
diff options
context:
space:
mode:
authorAlex Tawse <alex.tawse@arm.com>2023-09-29 15:55:38 +0100
committerRichard <richard.burton@arm.com>2023-10-26 12:35:48 +0000
commitdaba3cf2e3633cbd0e4f8aabe7578b97e88deee1 (patch)
tree51024b8025e28ecb2aecd67246e189e25f5a6e6c /scripts/py
parenta11976fb866f77305708f832e603b963969e6a14 (diff)
downloadml-embedded-evaluation-kit-daba3cf2e3633cbd0e4f8aabe7578b97e88deee1.tar.gz
MLECO-3995: Pylint + Shellcheck compatibility
* All Python scripts updated to abide by Pylint rules * good-names updated to permit short variable names: i, j, k, f, g, ex * ignore-long-lines regex updated to allow long lines for licence headers * Shell scripts now compliant with Shellcheck Signed-off-by: Alex Tawse <Alex.Tawse@arm.com> Change-Id: I8d5af8279bc08bb8acfe8f6ee7df34965552bbe5
Diffstat (limited to 'scripts/py')
-rw-r--r--scripts/py/check_update_resources_downloaded.py54
-rw-r--r--scripts/py/dependency_urls.json8
-rw-r--r--scripts/py/gen_audio.py107
-rw-r--r--scripts/py/gen_audio_cpp.py258
-rw-r--r--scripts/py/gen_default_input_cpp.py49
-rw-r--r--scripts/py/gen_labels_cpp.py74
-rw-r--r--scripts/py/gen_model_cpp.py89
-rw-r--r--scripts/py/gen_rgb_cpp.py203
-rw-r--r--scripts/py/gen_test_data_cpp.py317
-rw-r--r--scripts/py/gen_utils.py194
-rwxr-xr-xscripts/py/git_pre_push_hooks.sh48
-rw-r--r--scripts/py/rnnoise_dump_extractor.py79
-rw-r--r--scripts/py/setup_hooks.py109
-rw-r--r--scripts/py/templates/header_template.txt2
-rw-r--r--scripts/py/use_case_resources.json190
15 files changed, 1337 insertions, 444 deletions
diff --git a/scripts/py/check_update_resources_downloaded.py b/scripts/py/check_update_resources_downloaded.py
index 6e4da21..bdd9d62 100644
--- a/scripts/py/check_update_resources_downloaded.py
+++ b/scripts/py/check_update_resources_downloaded.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+"""
+Contains methods to check if the downloaded resources need to be refreshed
+"""
+import hashlib
import json
import sys
-import hashlib
+import typing
from argparse import ArgumentParser
from pathlib import Path
-def get_md5sum_for_file(filepath: str) -> str:
+def get_md5sum_for_file(
+ filepath: typing.Union[str, Path]
+) -> str:
"""
Function to calculate md5sum for contents of a given file.
@@ -41,7 +46,7 @@ def get_md5sum_for_file(filepath: str) -> str:
def check_update_resources_downloaded(
- resource_downloaded_dir: str, set_up_script_path: str
+ resource_downloaded_dir: str, set_up_script_path: str
):
"""
Function that check if the resources downloaded need to be refreshed.
@@ -55,27 +60,27 @@ def check_update_resources_downloaded(
metadata_file_path = Path(resource_downloaded_dir) / "resources_downloaded_metadata.json"
if metadata_file_path.is_file():
- with open(metadata_file_path) as metadata_json:
-
+ with open(metadata_file_path, encoding="utf8") as metadata_json:
metadata_dict = json.load(metadata_json)
- md5_key = 'set_up_script_md5sum'
- set_up_script_md5sum_metadata = ''
- if md5_key in metadata_dict.keys():
- set_up_script_md5sum_metadata = metadata_dict["set_up_script_md5sum"]
+ md5_key = 'set_up_script_md5sum'
+ set_up_script_md5sum_metadata = ''
+
+ if md5_key in metadata_dict.keys():
+ set_up_script_md5sum_metadata = metadata_dict["set_up_script_md5sum"]
- set_up_script_md5sum_current = get_md5sum_for_file(set_up_script_path)
+ set_up_script_md5sum_current = get_md5sum_for_file(set_up_script_path)
- if set_up_script_md5sum_current == set_up_script_md5sum_metadata:
- return 0
+ if set_up_script_md5sum_current == set_up_script_md5sum_metadata:
+ return 0
- # Return code 1 if the resources need to be refreshed.
- print('Error: hash mismatch!')
- print(f'Metadata: {set_up_script_md5sum_metadata}')
- print(f'Current : {set_up_script_md5sum_current}')
- return 1
+ # Return code 1 if the resources need to be refreshed.
+ print('Error: hash mismatch!')
+ print(f'Metadata: {set_up_script_md5sum_metadata}')
+ print(f'Current : {set_up_script_md5sum_current}')
+ return 1
- # Return error code 2 if the file doesn't exists.
+ # Return error code 2 if the file doesn't exist.
print(f'Error: could not find {metadata_file_path}')
return 2
@@ -99,7 +104,8 @@ if __name__ == "__main__":
raise ValueError(f'Invalid script path: {args.setup_script_path}')
# Check the resources are downloaded as expected
- status = check_update_resources_downloaded(
- args.resource_downloaded_dir,
- args.setup_script_path)
- sys.exit(status)
+ STATUS = check_update_resources_downloaded(
+ args.resource_downloaded_dir,
+ args.setup_script_path
+ )
+ sys.exit(STATUS)
diff --git a/scripts/py/dependency_urls.json b/scripts/py/dependency_urls.json
new file mode 100644
index 0000000..33a84f7
--- /dev/null
+++ b/scripts/py/dependency_urls.json
@@ -0,0 +1,8 @@
+{
+ "cmsis": "https://github.com/ARM-software/CMSIS_5/archive/a75f01746df18bb5b929dfb8dc6c9407fac3a0f3.zip",
+ "cmsis-dsp": "https://github.com/ARM-software/CMSIS-DSP/archive/refs/tags/v1.15.0.zip",
+ "cmsis-nn": "https://github.com/ARM-software/CMSIS-NN/archive/refs/85164a811917770d7027a12a57ed3b469dac6537.zip",
+ "core-driver": "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/snapshot/ethos-u-core-driver-23.08.tar.gz",
+ "core-platform": "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-platform.git/snapshot/ethos-u-core-platform-23.08.tar.gz",
+ "tensorflow": "https://github.com/tensorflow/tflite-micro/archive/568d181ccc1f60e49742fd43b7f97141ee8d45fc.zip"
+}
diff --git a/scripts/py/gen_audio.py b/scripts/py/gen_audio.py
index ff33bfb..4d7318c 100644
--- a/scripts/py/gen_audio.py
+++ b/scripts/py/gen_audio.py
@@ -1,6 +1,6 @@
#!env/bin/python3
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,34 +17,99 @@
"""
Utility script to convert an audio clip into eval platform desired spec.
"""
-import soundfile as sf
-
from argparse import ArgumentParser
from os import path
-from gen_utils import AudioUtils
+import soundfile as sf
+
+from gen_utils import GenUtils
parser = ArgumentParser()
-parser.add_argument("--audio_path", help="Audio file path", required=True)
-parser.add_argument("--output_dir", help="Output directory", required=True)
-parser.add_argument("--sampling_rate", type=int, help="target sampling rate.", default=16000)
-parser.add_argument("--mono", type=bool, help="convert signal to mono.", default=True)
-parser.add_argument("--offset", type=float, help="start reading after this time (in seconds).", default=0)
-parser.add_argument("--duration", type=float, help="only load up to this much audio (in seconds).", default=0)
-parser.add_argument("--res_type", type=AudioUtils.res_data_type, help=f"Resample type: {AudioUtils.res_type_list()}.", default='kaiser_best')
-parser.add_argument("--min_samples", type=int, help="Minimum sample number.", default=16000)
-parser.add_argument("-v", "--verbosity", action="store_true")
-args = parser.parse_args()
+
+# pylint: disable=duplicate-code
+parser.add_argument(
+ "--audio_path",
+ help="Audio file path",
+ required=True
+)
+
+parser.add_argument(
+ "--output_dir",
+ help="Output directory",
+ required=True
+)
+
+parser.add_argument(
+ "--sampling_rate",
+ type=int,
+ help="target sampling rate.",
+ default=16000
+)
+
+parser.add_argument(
+ "--mono",
+ type=bool,
+ help="convert signal to mono.",
+ default=True
+)
+
+parser.add_argument(
+ "--offset",
+ type=float,
+ help="start reading after this time (in seconds).",
+ default=0
+)
+
+parser.add_argument(
+ "--duration",
+ type=float,
+ help="only load up to this much audio (in seconds).",
+ default=0
+)
+
+parser.add_argument(
+ "--res_type",
+ type=GenUtils.res_data_type,
+ help=f"Resample type: {GenUtils.res_type_list()}.",
+ default='kaiser_best'
+)
+
+parser.add_argument(
+ "--min_samples",
+ type=int,
+ help="Minimum sample number.",
+ default=16000
+)
+
+parser.add_argument(
+ "-v",
+ "--verbosity",
+ action="store_true"
+)
+# pylint: enable=duplicate-code
+
+parsed_args = parser.parse_args()
def main(args):
- audio_data, samplerate = AudioUtils.load_resample_audio_clip(args.audio_path,
- args.sampling_rate,
- args.mono, args.offset,
- args.duration, args.res_type,
- args.min_samples)
- sf.write(path.join(args.output_dir, path.basename(args.audio_path)), audio_data, samplerate)
+ """
+ Generate the new audio file
+ @param args: Parsed args
+ """
+ audio_sample = GenUtils.read_audio_file(
+ args.audio_path, args.offset, args.duration
+ )
+
+ resampled_audio = GenUtils.resample_audio_clip(
+ audio_sample, args.sampling_rate, args.mono, args.res_type, args.min_samples
+ )
+
+ sf.write(
+ path.join(args.output_dir, path.basename(args.audio_path)),
+ resampled_audio.data,
+ resampled_audio.sample_rate
+ )
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_audio_cpp.py b/scripts/py/gen_audio_cpp.py
index 850a871..89d9ae1 100644
--- a/scripts/py/gen_audio_cpp.py
+++ b/scripts/py/gen_audio_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,93 +21,217 @@ from the cpp files.
import datetime
import glob
import math
-from pathlib import Path
from argparse import ArgumentParser
+from pathlib import Path
import numpy as np
from jinja2 import Environment, FileSystemLoader
-from gen_utils import AudioUtils
+from gen_utils import GenUtils, AudioSample
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--audio_path", type=str, help="path to audio folder to convert.")
-parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.")
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.")
-parser.add_argument("--sampling_rate", type=int, help="target sampling rate.", default=16000)
-parser.add_argument("--mono", type=bool, help="convert signal to mono.", default=True)
-parser.add_argument("--offset", type=float, help="start reading after this time (in seconds).", default=0)
-parser.add_argument("--duration", type=float, help="only load up to this much audio (in seconds).", default=0)
-parser.add_argument("--res_type", type=AudioUtils.res_data_type, help=f"Resample type: {AudioUtils.res_type_list()}.",
- default='kaiser_best')
-parser.add_argument("--min_samples", type=int, help="Minimum sample number.", default=16000)
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-parser.add_argument("-v", "--verbosity", action="store_true")
-args = parser.parse_args()
+
+parser.add_argument(
+ "--audio_path",
+ type=str,
+ help="path to audio folder to convert."
+)
+
+parser.add_argument(
+ "--source_folder_path",
+ type=str,
+ help="path to source folder to be generated."
+)
+
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated."
+)
+
+parser.add_argument(
+ "--sampling_rate",
+ type=int,
+ help="target sampling rate.",
+ default=16000
+)
+
+parser.add_argument(
+ "--mono",
+ type=bool,
+ help="convert signal to mono.",
+ default=True
+)
+
+parser.add_argument(
+ "--offset",
+ type=float,
+ help="start reading after this time (in seconds).",
+ default=0
+)
+
+parser.add_argument(
+ "--duration",
+ type=float,
+ help="only load up to this much audio (in seconds).",
+ default=0
+)
+
+parser.add_argument(
+ "--res_type",
+ type=GenUtils.res_data_type,
+ help=f"Resample type: {GenUtils.res_type_list()}.",
+ default='kaiser_best'
+)
+
+parser.add_argument(
+ "--min_samples",
+ type=int,
+ help="Minimum sample number.",
+ default=16000
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parser.add_argument(
+ "-v",
+ "--verbosity",
+ action="store_true"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
-def write_hpp_file(header_filepath, cc_filepath, header_template_file, num_audios, audio_filenames, audio_array_namesizes):
+# pylint: enable=duplicate-code
+def write_hpp_file(
+ header_filepath,
+ header,
+ num_audios,
+ audio_array_namesizes
+):
+ """
+ Write audio hpp file
+
+ @param header_filepath: .hpp filepath
+ @param header: Rendered header
+ @param num_audios: Audio file index
+ @param audio_array_namesizes: Audio array name sizes
+ """
print(f"++ Generating {header_filepath}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
- env.get_template('AudioClips.hpp.template').stream(common_template_header=hdr,
- clips_count=num_audios,
- varname_size=audio_array_namesizes
- ) \
+ env \
+ .get_template('AudioClips.hpp.template') \
+ .stream(common_template_header=header,
+ clips_count=num_audios,
+ varname_size=audio_array_namesizes) \
.dump(str(header_filepath))
+
+def write_cc_file(
+ cc_filepath,
+ header,
+ num_audios,
+ audio_filenames,
+ audio_array_namesizes
+):
+ """
+ Write cc file
+
+ @param cc_filepath: .cc filepath
+ @param header: Rendered header
+ @param num_audios: Audio file index
+ @param audio_filenames: Audio filenames
+ @param audio_array_namesizes: Audio array name sizes
+ """
print(f"++ Generating {cc_filepath}")
- env.get_template('AudioClips.cc.template').stream(common_template_header=hdr,
- clips_count=num_audios,
- var_names=(name for name, _ in audio_array_namesizes),
- clip_sizes=(size for _, size in audio_array_namesizes),
- clip_names=audio_filenames) \
+ env \
+ .get_template('AudioClips.cc.template') \
+ .stream(common_template_header=header,
+ clips_count=num_audios,
+ var_names=(name for name, _ in audio_array_namesizes),
+ clip_sizes=(size for _, size in audio_array_namesizes),
+ clip_names=audio_filenames) \
.dump(str(cc_filepath))
-def write_individual_audio_cc_file(clip_dirpath, clip_filename,
- cc_filename, header_template_file, array_name,
- sampling_rate_value, mono_value, offset_value,
- duration_value, res_type_value, min_len):
+def write_individual_audio_cc_file(
+ resampled_audio: AudioSample,
+ clip_filename,
+ cc_filename,
+ header_template_file,
+ array_name
+):
+ """
+ Writes the provided audio sample to a .cc file
+
+ @param resampled_audio: Audio sample to write
+ @param clip_filename: File name of the clip
+ @param cc_filename: File name of the .cc file
+ @param header_template_file: Header template
+ @param array_name: Name of the array to write
+ @return: Array length of the audio data written
+ """
print(f"++ Converting {clip_filename} to {Path(cc_filename).name}")
- audio_filepath = Path(clip_dirpath) / clip_filename
- clip_data, samplerate = AudioUtils.load_resample_audio_clip(audio_filepath,
- sampling_rate_value, mono_value,
- offset_value, duration_value,
- res_type_value, min_len)
# Change from [-1, 1] fp32 range to int16 range.
- clip_data = np.clip((clip_data * (1 << 15)),
+ clip_data = np.clip((resampled_audio.data * (1 << 15)),
np.iinfo(np.int16).min,
np.iinfo(np.int16).max).flatten().astype(np.int16)
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- file_name=clip_filename,
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, header_template_file, clip_filename)
hex_line_generator = (', '.join(map(hex, sub_arr))
- for sub_arr in np.array_split(clip_data, math.ceil(len(clip_data)/20)))
+ for sub_arr in np.array_split(clip_data, math.ceil(len(clip_data) / 20)))
- env.get_template('audio.cc.template').stream(common_template_header=hdr,
- size=len(clip_data),
- var_name=array_name,
- audio_data=hex_line_generator) \
+ env \
+ .get_template('audio.cc.template') \
+ .stream(common_template_header=hdr,
+ size=len(clip_data),
+ var_name=array_name,
+ audio_data=hex_line_generator) \
.dump(str(cc_filename))
return len(clip_data)
+def create_audio_cc_file(args, filename, array_name, clip_dirpath):
+ """
+ Create an individual audio cpp file
+
+ @param args: User-specified args
+ @param filename: Audio filename
+ @param array_name: Name of the array in the audio .cc file
+ @param clip_dirpath: Audio file directory path
+ @return: Array length of the audio data written
+ """
+ cc_filename = (Path(args.source_folder_path) /
+ (Path(filename).stem.replace(" ", "_") + ".cc"))
+ audio_filepath = Path(clip_dirpath) / filename
+ audio_sample = GenUtils.read_audio_file(audio_filepath, args.offset, args.duration)
+ resampled_audio = GenUtils.resample_audio_clip(
+ audio_sample, args.sampling_rate, args.mono, args.res_type, args.min_samples
+ )
+ return write_individual_audio_cc_file(
+ resampled_audio, filename, cc_filename, args.license_template, array_name,
+ )
+
+
def main(args):
+ """
+ Convert audio files to .cc + .hpp files
+ @param args: Parsed args
+ """
# Keep the count of the audio files converted
audioclip_idx = 0
audioclip_filenames = []
@@ -131,25 +255,41 @@ def main(args):
audioclip_filenames.append(filename)
# Save the cc file
- cc_filename = Path(args.source_folder_path) / (Path(filename).stem.replace(" ", "_") + ".cc")
array_name = "audio" + str(audioclip_idx)
- array_size = write_individual_audio_cc_file(clip_dirpath, filename, cc_filename, args.license_template, array_name,
- args.sampling_rate, args.mono, args.offset,
- args.duration, args.res_type, args.min_samples)
+ array_size = create_audio_cc_file(args, filename, array_name, clip_dirpath)
audioclip_array_names.append((array_name, array_size))
# Increment audio index
audioclip_idx = audioclip_idx + 1
- except:
+ except OSError:
if args.verbosity:
print(f"Failed to open {filename} as an audio.")
if len(audioclip_filenames) > 0:
- write_hpp_file(header_filepath, common_cc_filepath, args.license_template,
- audioclip_idx, audioclip_filenames, audioclip_array_names)
+ header = env \
+ .get_template(args.license_template) \
+ .render(script_name=Path(__file__).name,
+ gen_time=datetime.datetime.now(),
+ year=datetime.datetime.now().year)
+
+ write_hpp_file(
+ header_filepath,
+ header,
+ audioclip_idx,
+ audioclip_array_names
+ )
+
+ write_cc_file(
+ common_cc_filepath,
+ header,
+ audioclip_idx,
+ audioclip_filenames,
+ audioclip_array_names
+ )
+
else:
raise FileNotFoundError("No valid audio clip files found.")
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_default_input_cpp.py b/scripts/py/gen_default_input_cpp.py
index 093a606..6056dc1 100644
--- a/scripts/py/gen_default_input_cpp.py
+++ b/scripts/py/gen_default_input_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,38 +16,61 @@
"""
Utility script to generate the minimum InputFiles.hpp and cpp files required by an application.
"""
-import datetime
-from pathlib import Path
from argparse import ArgumentParser
+from pathlib import Path
from jinja2 import Environment, FileSystemLoader
+from gen_utils import GenUtils
+
parser = ArgumentParser()
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.")
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-args = parser.parse_args()
+
+# pylint: disable=duplicate-code
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated."
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
+# pylint: enable=duplicate-code
def write_hpp_file(header_file_path, header_template_file):
+ """
+ Write .hpp file
+ @param header_file_path: Header file path
+ @param header_template_file: Header template file
+ """
print(f"++ Generating {header_file_path}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
- env.get_template('default.hpp.template').stream(common_template_header=hdr) \
+ hdr = GenUtils.gen_header(env, header_template_file)
+ env \
+ .get_template('default.hpp.template') \
+ .stream(common_template_header=hdr) \
.dump(str(header_file_path))
def main(args):
+ """
+ Generate InputFiles.hpp + .cpp
+ @param args: Parsed args
+ """
header_filename = "InputFiles.hpp"
header_filepath = Path(args.header_folder_path) / header_filename
write_hpp_file(header_filepath, args.license_template)
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_labels_cpp.py b/scripts/py/gen_labels_cpp.py
index 065ed5d..11d5040 100644
--- a/scripts/py/gen_labels_cpp.py
+++ b/scripts/py/gen_labels_cpp.py
@@ -1,6 +1,6 @@
#!env/bin/python3
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,47 +21,83 @@ NN model output vector) into a vector list initialiser. The intention is for
this script to be called as part of the build framework to auto-generate the
cpp file with labels that can be used in the application without modification.
"""
-import datetime
-from pathlib import Path
from argparse import ArgumentParser
+from pathlib import Path
from jinja2 import Environment, FileSystemLoader
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
# Label file path
-parser.add_argument("--labels_file", type=str, help="Path to the label text file", required=True)
+parser.add_argument(
+ "--labels_file",
+ type=str,
+ help="Path to the label text file",
+ required=True
+)
+
# Output file to be generated
-parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.", required=True)
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.", required=True)
-parser.add_argument("--output_file_name", type=str, help="Required output file name", required=True)
+parser.add_argument(
+ "--source_folder_path",
+ type=str,
+ help="path to source folder to be generated.",
+ required=True
+)
+
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated.",
+ required=True
+)
+
+parser.add_argument(
+ "--output_file_name",
+ type=str,
+ help="Required output file name",
+ required=True
+)
+
# Namespaces
-parser.add_argument("--namespaces", action='append', default=[])
+parser.add_argument(
+ "--namespaces",
+ action='append',
+ default=[]
+)
+
# License template
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
-args = parser.parse_args()
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
+# pylint: enable=duplicate-code
def main(args):
+ """
+ Generate labels .cpp
+ @param args: Parsed args
+ """
# Get the labels from text file
- with open(args.labels_file, "r") as f:
+ with open(args.labels_file, "r", encoding="utf8") as f:
labels = f.read().splitlines()
# No labels?
if len(labels) == 0:
- raise Exception(f"no labels found in {args.label_file}")
+ raise ValueError(f"no labels found in {args.label_file}")
- header_template = env.get_template(args.license_template)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- file_name=Path(args.labels_file).name,
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, args.license_template, Path(args.labels_file).name)
hpp_filename = Path(args.header_folder_path) / (args.output_file_name + ".hpp")
env.get_template('Labels.hpp.template').stream(common_template_header=hdr,
@@ -78,4 +114,4 @@ def main(args):
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py
index e4933b5..933c189 100644
--- a/scripts/py/gen_model_cpp.py
+++ b/scripts/py/gen_model_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,29 +18,67 @@ Utility script to generate model c file that can be included in the
project directly. This should be called as part of cmake framework
should the models need to be generated at configuration stage.
"""
-import datetime
+import binascii
from argparse import ArgumentParser
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
-import binascii
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True)
-parser.add_argument("--output_dir", help="Output directory", required=True)
-parser.add_argument('-e', '--expression', action='append', default=[], dest="expr")
-parser.add_argument('--header', action='append', default=[], dest="headers")
-parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces")
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-args = parser.parse_args()
+parser.add_argument(
+ "--tflite_path",
+ help="Model (.tflite) path",
+ required=True
+)
+
+parser.add_argument(
+ "--output_dir",
+ help="Output directory",
+ required=True
+)
+
+parser.add_argument(
+ '-e',
+ '--expression',
+ action='append',
+ default=[],
+ dest="expr"
+)
+
+parser.add_argument(
+ '--header',
+ action='append',
+ default=[],
+ dest="headers"
+)
+
+parser.add_argument(
+ '-ns',
+ '--namespaces',
+ action='append',
+ default=[],
+ dest="namespaces"
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
+# pylint: enable=duplicate-code
def get_tflite_data(tflite_path: str) -> list:
"""
Reads a binary file and returns a C style array as a
@@ -63,15 +101,19 @@ def get_tflite_data(tflite_path: str) -> list:
for i in range(0, len(hexstream), 2):
if 0 == (i % hex_digits_per_line):
hexstring += "\n"
- hexstring += '0x' + hexstream[i:i+2] + ", "
+ hexstring += '0x' + hexstream[i:i + 2] + ", "
hexstring += '};\n'
return [hexstring]
def main(args):
+ """
+ Generate models .cpp
+ @param args: Parsed args
+ """
if not Path(args.tflite_path).is_file():
- raise Exception(f"{args.tflite_path} not found")
+ raise ValueError(f"{args.tflite_path} not found")
# Cpp filename:
cpp_filename = (Path(args.output_dir) / (Path(args.tflite_path).name + ".cc")).resolve()
@@ -80,19 +122,16 @@ def main(args):
cpp_filename.parent.mkdir(exist_ok=True)
- header_template = env.get_template(args.license_template)
-
- hdr = header_template.render(script_name=Path(__file__).name,
- file_name=Path(args.tflite_path).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, args.license_template, Path(args.tflite_path).name)
- env.get_template('tflite.cc.template').stream(common_template_header=hdr,
- model_data=get_tflite_data(args.tflite_path),
- expressions=args.expr,
- additional_headers=args.headers,
- namespaces=args.namespaces).dump(str(cpp_filename))
+ env \
+ .get_template('tflite.cc.template') \
+ .stream(common_template_header=hdr,
+ model_data=get_tflite_data(args.tflite_path),
+ expressions=args.expr,
+ additional_headers=args.headers,
+ namespaces=args.namespaces).dump(str(cpp_filename))
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_rgb_cpp.py b/scripts/py/gen_rgb_cpp.py
index b8d85ee..e1c93bb 100644
--- a/scripts/py/gen_rgb_cpp.py
+++ b/scripts/py/gen_rgb_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,87 +18,179 @@ Utility script to convert a set of RGB images in a given location into
corresponding cpp files and a single hpp file referencing the vectors
from the cpp files.
"""
-import datetime
import glob
import math
-from pathlib import Path
+import typing
from argparse import ArgumentParser
+from dataclasses import dataclass
+from pathlib import Path
import numpy as np
from PIL import Image, UnidentifiedImageError
from jinja2 import Environment, FileSystemLoader
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--image_path", type=str, help="path to images folder or image file to convert.")
-parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.")
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.")
-parser.add_argument("--image_size", type=int, nargs=2, help="Size (width and height) of the converted images.")
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-args = parser.parse_args()
+
+parser.add_argument(
+ "--image_path",
+ type=str,
+ help="path to images folder or image file to convert."
+)
+
+parser.add_argument(
+ "--source_folder_path",
+ type=str,
+ help="path to source folder to be generated."
+)
+
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated."
+)
+
+parser.add_argument(
+ "--image_size",
+ type=int,
+ nargs=2,
+ help="Size (width and height) of the converted images."
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
-def write_hpp_file(header_file_path, cc_file_path, header_template_file, num_images, image_filenames,
- image_array_names, image_size):
+# pylint: enable=duplicate-code
+@dataclass
+class ImagesParams:
+ """
+ Template params for Images.hpp and Images.cc
+ """
+ num_images: int
+ image_size: typing.Sequence
+ image_array_names: typing.List[str]
+ image_filenames: typing.List[str]
+
+
+def write_hpp_file(
+ images_params: ImagesParams,
+ header_file_path: Path,
+ cc_file_path: Path,
+ header_template_file: str,
+):
+ """
+ Write Images.hpp and Images.cc
+
+ @param images_params: Template params
+ @param header_file_path: Images.hpp path
+ @param cc_file_path: Images.cc path
+ @param header_template_file: Header template file name
+ """
print(f"++ Generating {header_file_path}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
- env.get_template('Images.hpp.template').stream(common_template_header=hdr,
- imgs_count=num_images,
- img_size=str(image_size[0] * image_size[1] * 3),
- var_names=image_array_names) \
+ hdr = GenUtils.gen_header(env, header_template_file)
+
+ image_size = str(images_params.image_size[0] * images_params.image_size[1] * 3)
+
+ env \
+ .get_template('Images.hpp.template') \
+ .stream(common_template_header=hdr,
+ imgs_count=images_params.num_images,
+ img_size=image_size,
+ var_names=images_params.image_array_names) \
.dump(str(header_file_path))
- env.get_template('Images.cc.template').stream(common_template_header=hdr,
- var_names=image_array_names,
- img_names=image_filenames) \
+ env \
+ .get_template('Images.cc.template') \
+ .stream(common_template_header=hdr,
+ var_names=images_params.image_array_names,
+ img_names=images_params.image_filenames) \
.dump(str(cc_file_path))
-def write_individual_img_cc_file(image_filename, cc_filename, header_template_file, original_image,
- image_size, array_name):
- print(f"++ Converting {image_filename} to {cc_filename.name}")
+def resize_crop_image(
+ original_image: Image.Image,
+ image_size: typing.Sequence
+) -> np.ndarray:
+ """
+ Resize and crop input image
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- file_name=image_filename,
- year=datetime.datetime.now().year)
+ @param original_image: Image to resize and crop
+ @param image_size: New image size
+ @return: Resized and cropped image
+ """
# IFM size
ifm_width = image_size[0]
ifm_height = image_size[1]
# Aspect ratio resize
- scale_ratio = (float)(max(ifm_width, ifm_height)) / (float)(min(original_image.size[0], original_image.size[1]))
- resized_width = (int)(original_image.size[0] * scale_ratio)
- resized_height = (int)(original_image.size[1] * scale_ratio)
- resized_image = original_image.resize([resized_width,resized_height], Image.Resampling.BILINEAR)
+ scale_ratio = (float(max(ifm_width, ifm_height))
+ / float(min(original_image.size[0], original_image.size[1])))
+ resized_width = int(original_image.size[0] * scale_ratio)
+ resized_height = int(original_image.size[1] * scale_ratio)
+ resized_image = original_image.resize(
+ size=(resized_width, resized_height),
+ resample=Image.Resampling.BILINEAR
+ )
# Crop the center of the image
resized_image = resized_image.crop((
- (resized_width - ifm_width) / 2, # left
- (resized_height - ifm_height) / 2, # top
- (resized_width + ifm_width) / 2, # right
+ (resized_width - ifm_width) / 2, # left
+ (resized_height - ifm_height) / 2, # top
+ (resized_width + ifm_width) / 2, # right
(resized_height + ifm_height) / 2 # bottom
- ))
+ ))
+
+ return np.array(resized_image, dtype=np.uint8).flatten()
+
+
+def write_individual_img_cc_file(
+ rgb_data: np.ndarray,
+ image_filename: str,
+ cc_filename: Path,
+ header_template_file: str,
+ array_name: str
+):
+ """
+ Write image.cc
+
+ @param rgb_data: Image data
+ @param image_filename: Image file name
+ @param cc_filename: image.cc path
+ @param header_template_file: Header template file name
+ @param array_name: C++ array name
+ """
+ print(f"++ Converting {image_filename} to {cc_filename.name}")
+
+ hdr = GenUtils.gen_header(env, header_template_file, image_filename)
- # Convert the image and write it to the cc file
- rgb_data = np.array(resized_image, dtype=np.uint8).flatten()
hex_line_generator = (', '.join(map(hex, sub_arr))
for sub_arr in np.array_split(rgb_data, math.ceil(len(rgb_data) / 20)))
- env.get_template('image.cc.template').stream(common_template_header=hdr,
- var_name=array_name,
- img_data=hex_line_generator) \
+ env \
+ .get_template('image.cc.template') \
+ .stream(common_template_header=hdr,
+ var_name=array_name,
+ img_data=hex_line_generator) \
.dump(str(cc_filename))
def main(args):
+ """
+ Convert images
+ @param args: Parsed args
+ """
# Keep the count of the images converted
image_idx = 0
image_filenames = []
@@ -123,26 +215,29 @@ def main(args):
image_filenames.append(filename)
# Save the cc file
- cc_filename = Path(args.source_folder_path) / (Path(filename).stem.replace(" ", "_") + ".cc")
+ cc_filename = (Path(args.source_folder_path) /
+ (Path(filename).stem.replace(" ", "_") + ".cc"))
array_name = "im" + str(image_idx)
image_array_names.append(array_name)
- write_individual_img_cc_file(filename, cc_filename, args.license_template,
- original_image, args.image_size, array_name)
+
+ rgb_data = resize_crop_image(original_image, args.image_size)
+ write_individual_img_cc_file(
+ rgb_data, filename, cc_filename, args.license_template, array_name
+ )
# Increment image index
image_idx = image_idx + 1
- header_filename = "InputFiles.hpp"
- header_filepath = Path(args.header_folder_path) / header_filename
- common_cc_filename = "InputFiles.cc"
- common_cc_filepath = Path(args.source_folder_path) / common_cc_filename
+ header_filepath = Path(args.header_folder_path) / "InputFiles.hpp"
+ common_cc_filepath = Path(args.source_folder_path) / "InputFiles.cc"
+
+ images_params = ImagesParams(image_idx, args.image_size, image_array_names, image_filenames)
if len(image_filenames) > 0:
- write_hpp_file(header_filepath, common_cc_filepath, args.license_template,
- image_idx, image_filenames, image_array_names, args.image_size)
+ write_hpp_file(images_params, header_filepath, common_cc_filepath, args.license_template)
else:
raise FileNotFoundError("No valid images found.")
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_test_data_cpp.py b/scripts/py/gen_test_data_cpp.py
index a9e2b75..1ee55ff 100644
--- a/scripts/py/gen_test_data_cpp.py
+++ b/scripts/py/gen_test_data_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 - 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,81 +18,170 @@ Utility script to convert a set of pairs of npy files in a given location into
corresponding cpp files and a single hpp file referencing the vectors
from the cpp files.
"""
-import datetime
import math
-import os
-import numpy as np
+import typing
+from argparse import ArgumentParser
+from dataclasses import dataclass
from pathlib import Path
-from argparse import ArgumentParser
+import numpy as np
from jinja2 import Environment, FileSystemLoader
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--data_folder_path", type=str, help="path to ifm-ofm npy folder to convert.")
-parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.")
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.")
-parser.add_argument("--usecase", type=str, default="", help="Test data file suffix.")
-parser.add_argument("--namespaces", action='append', default=[])
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-parser.add_argument("-v", "--verbosity", action="store_true")
-args = parser.parse_args()
+parser.add_argument(
+ "--data_folder_path",
+ type=str,
+ help="path to ifm-ofm npy folder to convert."
+)
+
+parser.add_argument(
+ "--source_folder_path",
+ type=str,
+ help="path to source folder to be generated."
+)
+
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated."
+)
+
+parser.add_argument(
+ "--usecase",
+ type=str,
+ default="",
+ help="Test data file suffix."
+)
+
+parser.add_argument(
+ "--namespaces",
+ action='append',
+ default=[]
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parser.add_argument(
+ "-v",
+ "--verbosity",
+ action="store_true"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
-def write_hpp_file(header_filename, cc_file_path, header_template_file, num_ifms, num_ofms,
- ifm_array_names, ifm_sizes, ofm_array_names, ofm_sizes, iofm_data_type):
- header_file_path = Path(args.header_folder_path) / header_filename
+# pylint: enable=duplicate-code
+@dataclass
+class TestDataParams:
+ """
+ Template params for TestData.hpp + TestData.ccc
+ """
+ ifm_count: int
+ ofm_count: int
+ ifm_var_names: typing.List[str]
+ ifm_var_sizes: typing.List[int]
+ ofm_var_names: typing.List[str]
+ ofm_var_sizes: typing.List[int]
+ data_type: str
+
+
+@dataclass
+class IofmParams:
+ """
+ Template params for iofmdata.cc
+ """
+ var_name: str
+ data_type: str
+
+
+def write_hpp_file(
+ template_params: TestDataParams,
+ header_filename: str,
+ cc_file_path: str,
+ header_template_file: str
+):
+ """
+ Write TestData.hpp and TestData.cc
+
+ @param template_params: Template parameters
+ @param header_filename: TestData.hpp path
+ @param cc_file_path: TestData.cc path
+ @param header_template_file: Header template file name
+ """
+ header_file_path = Path(parsed_args.header_folder_path) / header_filename
print(f"++ Generating {header_file_path}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
- env.get_template('TestData.hpp.template').stream(common_template_header=hdr,
- ifm_count=num_ifms,
- ofm_count=num_ofms,
- ifm_var_names=ifm_array_names,
- ifm_var_sizes=ifm_sizes,
- ofm_var_names=ofm_array_names,
- ofm_var_sizes=ofm_sizes,
- data_type=iofm_data_type,
- namespaces=args.namespaces) \
+ hdr = GenUtils.gen_header(env, header_template_file)
+ env \
+ .get_template('TestData.hpp.template') \
+ .stream(common_template_header=hdr,
+ ifm_count=template_params.ifm_count,
+ ofm_count=template_params.ofm_count,
+ ifm_var_names=template_params.ifm_var_names,
+ ifm_var_sizes=template_params.ifm_var_sizes,
+ ofm_var_names=template_params.ofm_var_names,
+ ofm_var_sizes=template_params.ofm_var_sizes,
+ data_type=template_params.data_type,
+ namespaces=parsed_args.namespaces) \
.dump(str(header_file_path))
- env.get_template('TestData.cc.template').stream(common_template_header=hdr,
- include_h=header_filename,
- ifm_var_names=ifm_array_names,
- ofm_var_names=ofm_array_names,
- data_type=iofm_data_type,
- namespaces=args.namespaces) \
+ env \
+ .get_template('TestData.cc.template') \
+ .stream(common_template_header=hdr,
+ include_h=header_filename,
+ ifm_var_names=template_params.ifm_var_names,
+ ofm_var_names=template_params.ofm_var_names,
+ data_type=template_params.data_type,
+ namespaces=parsed_args.namespaces) \
.dump(str(cc_file_path))
-def write_individual_cc_file(filename, cc_filename, header_filename, header_template_file, array_name, iofm_data_type):
+def write_individual_cc_file(
+ template_params: IofmParams,
+ header_filename: str,
+ filename: str,
+ cc_filename: Path,
+ header_template_file: str
+):
+ """
+ Write iofmdata.cc
+
+ @param template_params: Template parameters
+ @param header_filename: Header file name
+ @param filename: Input file name
+ @param cc_filename: iofmdata.cc file name
+ @param header_template_file: Header template file name
+ """
print(f"++ Converting {filename} to {cc_filename.name}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- file_name=filename,
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, header_template_file, filename)
# Convert the image and write it to the cc file
- fm_data = (np.load(Path(args.data_folder_path) / filename)).flatten()
+ fm_data = (np.load(Path(parsed_args.data_folder_path) / filename)).flatten()
type(fm_data.dtype)
hex_line_generator = (', '.join(map(hex, sub_arr))
for sub_arr in np.array_split(fm_data, math.ceil(len(fm_data) / 20)))
- env.get_template('iofmdata.cc.template').stream(common_template_header=hdr,
- include_h=header_filename,
- var_name=array_name,
- fm_data=hex_line_generator,
- data_type=iofm_data_type,
- namespaces=args.namespaces) \
+ env \
+ .get_template('iofmdata.cc.template') \
+ .stream(common_template_header=hdr,
+ include_h=header_filename,
+ var_name=template_params.var_name,
+ fm_data=hex_line_generator,
+ data_type=template_params.data_type,
+ namespaces=parsed_args.namespaces) \
.dump(str(cc_filename))
@@ -104,59 +193,117 @@ def get_npy_vec_size(filename: str) -> int:
Return:
size in bytes
"""
- data = np.load(Path(args.data_folder_path) / filename)
+ data = np.load(Path(parsed_args.data_folder_path) / filename)
return data.size * data.dtype.itemsize
-def main(args):
- # Keep the count of the images converted
- ifm_array_names = []
- ofm_array_names = []
+def write_cc_files(args, count, iofm_data_type, add_usecase_fname, prefix):
+ """
+ Write all cc files
+
+ @param args: User-provided args
+ @param count: File count
+ @param iofm_data_type: Data type
+ @param add_usecase_fname: Use case suffix
+ @param prefix: Prefix (ifm/ofm)
+ @return: Names and sizes of generated C++ arrays
+ """
+ array_names = []
+ sizes = []
+
+ header_filename = get_header_filename(add_usecase_fname)
+ # In the data_folder_path there should be pairs of ifm-ofm
+ # It's assumed the ifm-ofm naming convention: ifm0.npy-ofm0.npy, ifm1.npy-ofm1.npy
+ # count = int(len(list(Path(args.data_folder_path).glob(f'{prefix}*.npy'))))
+
+ for idx in range(count):
+ # Save the fm cc file
+ base_name = prefix + str(idx)
+ filename = base_name + ".npy"
+ array_name = base_name + add_usecase_fname
+ cc_filename = Path(args.source_folder_path) / (array_name + ".cc")
+ array_names.append(array_name)
+
+ template_params = IofmParams(
+ var_name=array_name,
+ data_type=iofm_data_type,
+ )
+
+ write_individual_cc_file(
+ template_params, header_filename, filename, cc_filename, args.license_template
+ )
+ sizes.append(get_npy_vec_size(filename))
+
+ return array_names, sizes
+
+
+def get_header_filename(use_case_filename):
+ """
+ Get the header file name from the use case file name
+
+ @param use_case_filename: The use case file name
+ @return: The header file name
+ """
+ return "TestData" + use_case_filename + ".hpp"
+
+
+def get_cc_filename(use_case_filename):
+ """
+ Get the cc file name from the use case file name
+
+ @param use_case_filename: The use case file name
+ @return: The cc file name
+ """
+ return "TestData" + use_case_filename + ".cc"
+
+
+def main(args):
+ """
+ Generate test data
+ @param args: Parsed args
+ """
add_usecase_fname = ("_" + args.usecase) if (args.usecase != "") else ""
- header_filename = "TestData" + add_usecase_fname + ".hpp"
- common_cc_filename = "TestData" + add_usecase_fname + ".cc"
+ header_filename = get_header_filename(add_usecase_fname)
+ common_cc_filename = get_cc_filename(add_usecase_fname)
# In the data_folder_path there should be pairs of ifm-ofm
# It's assumed the ifm-ofm naming convention: ifm0.npy-ofm0.npy, ifm1.npy-ofm1.npy
ifms_count = int(len(list(Path(args.data_folder_path).glob('ifm*.npy'))))
ofms_count = int(len(list(Path(args.data_folder_path).glob('ofm*.npy'))))
- #i_ofms_count = int(len([name for name in os.listdir(os.path.join(args.data_folder_path)) if name.lower().endswith('.npy')]) / 2)
-
iofm_data_type = "int8_t"
if ifms_count > 0:
- iofm_data_type = "int8_t" if (np.load(Path(args.data_folder_path) / "ifm0.npy").dtype == np.int8) else "uint8_t"
-
- ifm_sizes = []
- ofm_sizes = []
+ iofm_data_type = "int8_t" \
+ if (np.load(str(Path(args.data_folder_path) / "ifm0.npy")).dtype == np.int8) \
+ else "uint8_t"
- for idx in range(ifms_count):
- # Save the fm cc file
- base_name = "ifm" + str(idx)
- filename = base_name+".npy"
- array_name = base_name + add_usecase_fname
- cc_filename = Path(args.source_folder_path) / (array_name + ".cc")
- ifm_array_names.append(array_name)
- write_individual_cc_file(filename, cc_filename, header_filename, args.license_template, array_name, iofm_data_type)
- ifm_sizes.append(get_npy_vec_size(filename))
+ ifm_array_names, ifm_sizes = write_cc_files(
+ args, ifms_count, iofm_data_type, add_usecase_fname, prefix="ifm"
+ )
- for idx in range(ofms_count):
- # Save the fm cc file
- base_name = "ofm" + str(idx)
- filename = base_name+".npy"
- array_name = base_name + add_usecase_fname
- cc_filename = Path(args.source_folder_path) / (array_name + ".cc")
- ofm_array_names.append(array_name)
- write_individual_cc_file(filename, cc_filename, header_filename, args.license_template, array_name, iofm_data_type)
- ofm_sizes.append(get_npy_vec_size(filename))
+ ofm_array_names, ofm_sizes = write_cc_files(
+ args, ofms_count, iofm_data_type, add_usecase_fname, prefix="ofm"
+ )
common_cc_filepath = Path(args.source_folder_path) / common_cc_filename
- write_hpp_file(header_filename, common_cc_filepath, args.license_template,
- ifms_count, ofms_count, ifm_array_names, ifm_sizes, ofm_array_names, ofm_sizes, iofm_data_type)
+
+ template_params = TestDataParams(
+ ifm_count=ifms_count,
+ ofm_count=ofms_count,
+ ifm_var_names=ifm_array_names,
+ ifm_var_sizes=ifm_sizes,
+ ofm_var_names=ofm_array_names,
+ ofm_var_sizes=ofm_sizes,
+ data_type=iofm_data_type,
+ )
+
+ write_hpp_file(
+ template_params, header_filename, common_cc_filepath, args.license_template
+ )
if __name__ == '__main__':
- if args.verbosity:
- print("Running gen_test_data_cpp with args: "+str(args))
- main(args)
+ if parsed_args.verbosity:
+ print("Running gen_test_data_cpp with args: " + str(parsed_args))
+ main(parsed_args)
diff --git a/scripts/py/gen_utils.py b/scripts/py/gen_utils.py
index ee33705..6bb4760 100644
--- a/scripts/py/gen_utils.py
+++ b/scripts/py/gen_utils.py
@@ -1,6 +1,6 @@
#!env/bin/python3
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,21 +14,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import soundfile as sf
-import resampy
+"""
+Utility functions for .cc + .hpp file generation
+"""
+import argparse
+import datetime
+from dataclasses import dataclass
+from pathlib import Path
+
+import jinja2
import numpy as np
+import resampy
+import soundfile as sf
-class AudioUtils:
+@dataclass
+class AudioSample:
+ """
+ Represents an audio sample with its sample rate
+ """
+ data: np.ndarray
+ sample_rate: int
+
+
+class GenUtils:
+ """
+ Class with utility functions for audio and other .cc + .hpp file generation
+ """
+
@staticmethod
def res_data_type(res_type_value):
"""
Returns the input string if is one of the valid resample type
"""
- import argparse
- if res_type_value not in AudioUtils.res_type_list():
- raise argparse.ArgumentTypeError(f"{res_type_value} not valid. Supported only {AudioUtils.res_type_list()}")
+ if res_type_value not in GenUtils.res_type_list():
+ raise argparse.ArgumentTypeError(
+ f"{res_type_value} not valid. Supported only {GenUtils.res_type_list()}"
+ )
return res_type_value
@staticmethod
@@ -39,27 +61,18 @@ class AudioUtils:
return ['kaiser_best', 'kaiser_fast']
@staticmethod
- def load_resample_audio_clip(path, target_sr=16000, mono=True, offset=0.0, duration=0, res_type='kaiser_best',
- min_len=16000):
+ def read_audio_file(
+ path,
+ offset,
+ duration
+ ) -> AudioSample:
"""
- Load and resample an audio clip with the given desired specs.
+ Reads an audio file to an array
- Parameters:
- ----------
- path (string): Path to the input audio clip.
- target_sr (int, optional): Target sampling rate. Positive number are considered valid,
- if zero or negative the native sampling rate of the file will be preserved. Default is 16000.
- mono (bool, optional): Specify if the audio file needs to be converted to mono. Default is True.
- offset (float, optional): Target sampling rate. Default is 0.0.
- duration (int, optional): Target duration. Positive number are considered valid,
- if zero or negative the duration of the file will be preserved. Default is 0.
- res_type (int, optional): Resample type to use, Default is 'kaiser_best'.
- min_len (int, optional): Minimun lenght of the output audio time series. Default is 16000.
-
- Returns:
- ----------
- y (np.ndarray): Output audio time series of shape shape=(n,) or (2, n).
- sr (int): A scalar number > 0 that represent the sampling rate of `y`
+ @param path: Path to audio file
+ @param offset: Offset to read from
+ @param duration: Duration to read
+ @return: The audio data and the sample rate
"""
try:
with sf.SoundFile(path) as audio_file:
@@ -76,40 +89,115 @@ class AudioUtils:
# Load the target number of frames
y = audio_file.read(frames=num_frame_duration, dtype=np.float32, always_2d=False).T
-
- except:
+ except OSError as err:
print(f"Failed to open {path} as an audio.")
+ raise err
+
+ return AudioSample(y, origin_sr)
+
+ @staticmethod
+ def _resample_audio(
+ y,
+ target_sr,
+ origin_sr,
+ res_type
+ ):
+ """
+ Resamples audio to a different sample rate
+
+ @param y: Audio to resample
+ @param target_sr: Target sample rate
+ @param origin_sr: Original sample rate
+ @param res_type: Resample type
+ @return: The resampled audio
+ """
+ ratio = float(target_sr) / origin_sr
+ axis = -1
+ n_samples = int(np.ceil(y.shape[axis] * ratio))
+
+ # Resample using resampy
+ y_rs = resampy.resample(y, origin_sr, target_sr, filter=res_type, axis=axis)
+ n_rs_samples = y_rs.shape[axis]
+
+ # Adjust the size
+ if n_rs_samples > n_samples:
+ slices = [slice(None)] * y_rs.ndim
+ slices[axis] = slice(0, n_samples)
+ y = y_rs[tuple(slices)]
+ elif n_rs_samples < n_samples:
+ lengths = [(0, 0)] * y_rs.ndim
+ lengths[axis] = (0, n_samples - n_rs_samples)
+ y = np.pad(y_rs, lengths, 'constant', constant_values=0)
+
+ return y
+
+ @staticmethod
+ def resample_audio_clip(
+ audio_sample: AudioSample,
+ target_sr=16000,
+ mono=True,
+ res_type='kaiser_best',
+ min_len=16000
+ ) -> AudioSample:
+ """
+ Load and resample an audio clip with the given desired specs.
+
+ Parameters:
+ ----------
+ path (string): Path to the input audio clip.
+ target_sr (int, optional): Target sampling rate. Positive number are considered valid,
+ if zero or negative the native sampling rate of the file
+ will be preserved. Default is 16000.
+ mono (bool, optional): Specify if the audio file needs to be converted to mono.
+ Default is True.
+ offset (float, optional): Target sampling rate. Default is 0.0.
+ duration (int, optional): Target duration. Positive number are considered valid,
+ if zero or negative the duration of the file
+ will be preserved. Default is 0.
+ res_type (int, optional): Resample type to use, Default is 'kaiser_best'.
+ min_len (int, optional): Minimum length of the output audio time series.
+ Default is 16000.
+
+ Returns:
+ ----------
+ y (np.ndarray): Output audio time series of shape=(n,) or (2, n).
+ sample_rate (int): A scalar number > 0 that represent the sampling rate of `y`
+ """
+ y = audio_sample.data.copy()
# Convert to mono if requested and if audio has more than one dimension
- if mono and (y.ndim > 1):
+ if mono and (audio_sample.data.ndim > 1):
y = np.mean(y, axis=0)
- if not (origin_sr == target_sr) and (target_sr > 0):
- ratio = float(target_sr) / origin_sr
- axis = -1
- n_samples = int(np.ceil(y.shape[axis] * ratio))
-
- # Resample using resampy
- y_rs = resampy.resample(y, origin_sr, target_sr, filter=res_type, axis=axis)
- n_rs_samples = y_rs.shape[axis]
-
- # Adjust the size
- if n_rs_samples > n_samples:
- slices = [slice(None)] * y_rs.ndim
- slices[axis] = slice(0, n_samples)
- y = y_rs[tuple(slices)]
- elif n_rs_samples < n_samples:
- lengths = [(0, 0)] * y_rs.ndim
- lengths[axis] = (0, n_samples - n_rs_samples)
- y = np.pad(y_rs, lengths, 'constant', constant_values=(0))
-
- sr = target_sr
+ if not (audio_sample.sample_rate == target_sr) and (target_sr > 0):
+ y = GenUtils._resample_audio(y, target_sr, audio_sample.sample_rate, res_type)
+ sample_rate = target_sr
else:
- sr = origin_sr
+ sample_rate = audio_sample.sample_rate
# Pad if necessary and min lenght is setted (min_len> 0)
if (y.shape[0] < min_len) and (min_len > 0):
sample_to_pad = min_len - y.shape[0]
- y = np.pad(y, (0, sample_to_pad), 'constant', constant_values=(0))
+ y = np.pad(y, (0, sample_to_pad), 'constant', constant_values=0)
+
+ return AudioSample(data=y, sample_rate=sample_rate)
- return y, sr
+ @staticmethod
+ def gen_header(
+ env: jinja2.Environment,
+ header_template_file: str,
+ file_name: str = None
+ ) -> str:
+ """
+ Generate common licence header
+
+ :param env: Jinja2 environment
+ :param header_template_file: Path to the licence header template
+ :param file_name: Optional generating script file name
+ :return: Generated licence header as a string
+ """
+ header_template = env.get_template(header_template_file)
+ return header_template.render(script_name=Path(__file__).name,
+ gen_time=datetime.datetime.now(),
+ file_name=file_name,
+ year=datetime.datetime.now().year)
diff --git a/scripts/py/git_pre_push_hooks.sh b/scripts/py/git_pre_push_hooks.sh
new file mode 100755
index 0000000..db5706f
--- /dev/null
+++ b/scripts/py/git_pre_push_hooks.sh
@@ -0,0 +1,48 @@
+#!/bin/sh
+# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Called by "git push" with no arguments. The hook should
+# exit with non-zero status after issuing an appropriate message if
+# it wants to stop the push.
+
+# shellcheck disable=SC2034,SC2162
+while read local_ref local_sha remote_ref remote_sha; do
+ # We should pass only added or modified C/C++ source files to cppcheck.
+ changed_files=$(git diff --name-only HEAD~1 HEAD | grep -iE "\.(c|cpp|cxx|cc|h|hpp|hxx)$" | cut -f 2)
+ if [ -n "$changed_files" ]; then
+ # shellcheck disable=SC2086
+ clang-format -style=file --dry-run --Werror $changed_files
+
+ exitcode1=$?
+ if [ $exitcode1 -ne 0 ]; then
+ echo "Formatting errors found in file: $changed_files. \
+ Please run:
+ \"clang-format -style=file -i $changed_files\"
+ to correct these errors"
+ exit $exitcode1
+ fi
+
+ # shellcheck disable=SC2086
+ cppcheck --enable=performance,portability --error-exitcode=1 --suppress=*:tests* $changed_files
+ exitcode2=$?
+ if [ $exitcode2 -ne 0 ]; then
+ exit $exitcode2
+ fi
+ fi
+ exit 0
+done
+
+exit 0
diff --git a/scripts/py/rnnoise_dump_extractor.py b/scripts/py/rnnoise_dump_extractor.py
index 715b922..9e6ff1f 100644
--- a/scripts/py/rnnoise_dump_extractor.py
+++ b/scripts/py/rnnoise_dump_extractor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,47 +20,84 @@ Example use:
python rnnoise_dump_extractor.py --dump_file output.bin --output_dir ./denoised_wavs/
"""
-import soundfile as sf
-import numpy as np
-
import argparse
-from os import path
import struct
+import typing
+from os import path
+
+import numpy as np
+import soundfile as sf
-def extract(fp, output_dir, export_npy):
+def extract(
+ dump_file: typing.IO,
+ output_dir: str,
+ export_npy: bool
+):
+ """
+ Extract audio file from RNNoise output dump
+
+ @param dump_file: Audio dump file location
+ @param output_dir: Output direction
+ @param export_npy: Whether to export the audio as .npy
+ """
while True:
- filename_length = struct.unpack("i", fp.read(4))[0]
+ filename_length = struct.unpack("i", dump_file.read(4))[0]
if filename_length == -1:
return
- filename = struct.unpack("{}s".format(filename_length), fp.read(filename_length))[0].decode('ascii')
- audio_clip_length = struct.unpack("I", fp.read(4))[0]
- output_file_name = path.join(output_dir, "denoised_{}".format(filename))
- audio_clip = fp.read(audio_clip_length)
-
- with sf.SoundFile(output_file_name, 'w', channels=1, samplerate=48000, subtype="PCM_16", endian="LITTLE") as wav_file:
+ filename = struct \
+ .unpack(f"{filename_length}s", dump_file.read(filename_length))[0] \
+ .decode('ascii')
+
+ audio_clip_length = struct.unpack("I", dump_file.read(4))[0]
+ output_file_name = path.join(output_dir, f"denoised_{filename}")
+ audio_clip = dump_file.read(audio_clip_length)
+
+ with sf.SoundFile(output_file_name, 'w', channels=1, samplerate=48000, subtype="PCM_16",
+ endian="LITTLE") as wav_file:
wav_file.buffer_write(audio_clip, dtype='int16')
- print("{} written to disk".format(output_file_name))
+ print(f"{output_file_name} written to disk")
if export_npy:
output_file_name += ".npy"
- pack_format = "{}h".format(int(audio_clip_length/2))
+ pack_format = f"{int(audio_clip_length / 2)}h"
npdata = np.array(struct.unpack(pack_format, audio_clip)).astype(np.int16)
np.save(output_file_name, npdata)
- print("{} written to disk".format(output_file_name))
+ print(f"{output_file_name} written to disk")
def main(args):
+ """
+ Run RNNoise audio dump extraction
+ @param args: Parsed args
+ """
extract(args.dump_file, args.output_dir, args.export_npy)
parser = argparse.ArgumentParser()
-parser.add_argument("--dump_file", type=argparse.FileType('rb'), help="Dump file with audio files to extract.", required=True)
-parser.add_argument("--output_dir", help="Output directory, Warning: Duplicated file names will be overwritten.", required=True)
-parser.add_argument("--export_npy", help="Export the audio buffer in NumPy format", action="store_true")
-args = parser.parse_args()
+
+parser.add_argument(
+ "--dump_file",
+ type=argparse.FileType('rb'),
+ help="Dump file with audio files to extract.",
+ required=True
+)
+
+parser.add_argument(
+ "--output_dir",
+ help="Output directory, Warning: Duplicated file names will be overwritten.",
+ required=True
+)
+
+parser.add_argument(
+ "--export_npy",
+ help="Export the audio buffer in NumPy format",
+ action="store_true"
+)
+
+parsed_args = parser.parse_args()
if __name__ == "__main__":
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/setup_hooks.py b/scripts/py/setup_hooks.py
index ead5e1f..dc3156c 100644
--- a/scripts/py/setup_hooks.py
+++ b/scripts/py/setup_hooks.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
-# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
-# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,84 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import os
-import sys
+"""
+Adds the git hooks script into the appropriate location
+"""
import argparse
+import os
+import shutil
import subprocess
-import stat
+import sys
+from pathlib import Path
+
+HOOKS_SCRIPT = "git_pre_push_hooks.sh"
+
-def set_hooks_dir(hooks_dir):
- command = 'git config core.hooksPath {}'.format(hooks_dir)
- subprocess.Popen(command.split(), stdout=subprocess.PIPE)
+def set_hooks_dir(hooks_dir: str):
+ """
+ Set the hooks path in the git configuration
+ @param hooks_dir: The hooks directory
+ """
+ command = f'git config core.hooksPath {hooks_dir}'
+ with subprocess.Popen(command.split(), stdout=subprocess.PIPE) as process:
+ process.communicate()
+ return_code = process.returncode
-def add_pre_push_hooks(hooks_dir):
+ if return_code != 0:
+ raise RuntimeError(f"Could not configure git hooks path, exited with code {return_code}")
+
+
+def add_pre_push_hooks(hooks_dir: str):
+ """
+ Copies the git hooks scripts into the specified location
+ @param hooks_dir: The specified git hooks directory
+ """
pre_push = "pre-push"
file_path = os.path.join(hooks_dir, pre_push)
file_exists = os.path.exists(file_path)
if file_exists:
os.remove(file_path)
- f = open(file_path, "a")
- f.write(
-'''#!/bin/sh
-# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Called by "git push" with no arguments. The hook should
-# exit with non-zero status after issuing an appropriate message if
-# it wants to stop the push.
-
-while read local_ref local_sha remote_ref remote_sha
-do
- # We should pass only added or modified C/C++ source files to cppcheck.
- changed_files=$(git diff --name-only HEAD~1 HEAD | grep -iE "\.(c|cpp|cxx|cc|h|hpp|hxx)$" | cut -f 2)
- if [ -n "$changed_files" ]; then
- clang-format -style=file --dry-run --Werror $changed_files
-
- exitcode1=$?
- if [ $exitcode1 -ne 0 ]; then
- echo "Formatting errors found in file: $changed_files.
- \nPlease run:\n\ \"clang-format -style=file -i $changed_files\"
- \nto correct these errors"
- exit $exitcode1
- fi
-
- cppcheck --enable=performance,portability --error-exitcode=1 --suppress=*:tests* $changed_files
- exitcode2=$?
- if [ $exitcode2 -ne 0 ]; then
- exit $exitcode2
- fi
- fi
- exit 0
-done
-exit 0'''
-)
+ script_path = Path(__file__).resolve().parent / HOOKS_SCRIPT
+ shutil.copy(script_path, hooks_dir)
- f.close()
- s = os.stat(file_path)
- os.chmod(file_path, s.st_mode | stat.S_IEXEC)
-parser = argparse.ArgumentParser()
-parser.add_argument("git_hooks_path")
-args = parser.parse_args()
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("git_hooks_path")
+ args = parser.parse_args()
-dir_exists = os.path.exists(args.git_hooks_path)
-if not dir_exists:
- print('Error! The Git hooks directory you supplied does not exist.')
- sys.exit()
+ if not os.path.exists(args.git_hooks_path):
+ print('Error! The Git hooks directory you supplied does not exist.')
+ sys.exit()
-add_pre_push_hooks(args.git_hooks_path)
-set_hooks_dir(args.git_hooks_path)
+ add_pre_push_hooks(args.git_hooks_path)
+ set_hooks_dir(args.git_hooks_path)
diff --git a/scripts/py/templates/header_template.txt b/scripts/py/templates/header_template.txt
index f6e3bdb..32bf71a 100644
--- a/scripts/py/templates/header_template.txt
+++ b/scripts/py/templates/header_template.txt
@@ -16,6 +16,6 @@
*/
/********************* Autogenerated file. DO NOT EDIT *******************
- * Generated from {{script_name}} tool {% if file_name %}and {{file_name}}{% endif %} file.
+ * Generated from {{script_name}} tool {% if file_name %}and {{file_name}} {% endif %}file.
* Date: {{gen_time}}
***************************************************************************/
diff --git a/scripts/py/use_case_resources.json b/scripts/py/use_case_resources.json
new file mode 100644
index 0000000..80fa28d
--- /dev/null
+++ b/scripts/py/use_case_resources.json
@@ -0,0 +1,190 @@
+[
+ {
+ "name": "ad",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "ad_medium_int8.tflite",
+ "url": "{url_prefix:0}ad_medium_int8.tflite"
+ },
+ {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
+ {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"}
+ ]
+ },
+ {
+ "name": "asr",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/"
+ ],
+ "resources": [
+ {
+ "name": "wav2letter_pruned_int8.tflite",
+ "url": "{url_prefix:0}wav2letter_pruned_int8.tflite"
+ },
+ {
+ "name": "ifm0.npy",
+ "url": "{url_prefix:0}testing_input/input_2_int8/0.npy"
+ },
+ {
+ "name": "ofm0.npy",
+ "url": "{url_prefix:0}testing_output/Identity_int8/0.npy"
+ }
+ ]
+ },
+ {
+ "name": "img_class",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "mobilenet_v2_1.0_224_INT8.tflite",
+ "url": "{url_prefix:0}mobilenet_v2_1.0_224_INT8.tflite"
+ },
+ {
+ "name": "ifm0.npy",
+ "url": "{url_prefix:0}testing_input/tfl.quantize/0.npy"
+ },
+ {
+ "name": "ofm0.npy",
+ "url": "{url_prefix:0}testing_output/MobilenetV2/Predictions/Reshape_11/0.npy"
+ }
+ ]
+ },
+ {
+ "name": "object_detection",
+ "url_prefix": [
+ "https://github.com/emza-vs/ModelZoo/blob/v1.0/object_detection/"
+ ],
+ "resources": [
+ {
+ "name": "yolo-fastest_192_face_v4.tflite",
+ "url": "{url_prefix:0}yolo-fastest_192_face_v4.tflite?raw=true"
+ }
+ ]
+ },
+ {
+ "name": "kws",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/"
+ ],
+ "resources": [
+ {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
+ {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"},
+ {
+ "name": "kws_micronet_m.tflite",
+ "url": "{url_prefix:0}kws_micronet_m.tflite"
+ }
+ ]
+ },
+ {
+ "name": "vww",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "vww4_128_128_INT8.tflite",
+ "url": "{url_prefix:0}vww4_128_128_INT8.tflite"
+ },
+ {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
+ {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"}
+ ]
+ },
+ {
+ "name": "kws_asr",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/",
+ "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "wav2letter_pruned_int8.tflite",
+ "url": "{url_prefix:0}wav2letter_pruned_int8.tflite"
+ },
+ {
+ "sub_folder": "asr",
+ "name": "ifm0.npy",
+ "url": "{url_prefix:0}testing_input/input_2_int8/0.npy"
+ },
+ {
+ "sub_folder": "asr",
+ "name": "ofm0.npy",
+ "url": "{url_prefix:0}testing_output/Identity_int8/0.npy"
+ },
+ {
+ "sub_folder": "kws",
+ "name": "ifm0.npy",
+ "url": "{url_prefix:1}testing_input/input/0.npy"
+ },
+ {
+ "sub_folder": "kws",
+ "name": "ofm0.npy",
+ "url": "{url_prefix:1}testing_output/Identity/0.npy"
+ },
+ {
+ "name": "kws_micronet_m.tflite",
+ "url": "{url_prefix:1}kws_micronet_m.tflite"
+ }
+ ]
+ },
+ {
+ "name": "noise_reduction",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/a061600058097a2785d6f1f7785e5a2d2a142955/models/noise_suppression/RNNoise/tflite_int8/"
+ ],
+ "resources": [
+ {"name": "rnnoise_INT8.tflite", "url": "{url_prefix:0}rnnoise_INT8.tflite"},
+ {
+ "name": "ifm0.npy",
+ "url": "{url_prefix:0}testing_input/main_input_int8/0.npy"
+ },
+ {
+ "name": "ifm1.npy",
+ "url": "{url_prefix:0}testing_input/vad_gru_prev_state_int8/0.npy"
+ },
+ {
+ "name": "ifm2.npy",
+ "url": "{url_prefix:0}testing_input/noise_gru_prev_state_int8/0.npy"
+ },
+ {
+ "name": "ifm3.npy",
+ "url": "{url_prefix:0}testing_input/denoise_gru_prev_state_int8/0.npy"
+ },
+ {
+ "name": "ofm0.npy",
+ "url": "{url_prefix:0}testing_output/Identity_int8/0.npy"
+ },
+ {
+ "name": "ofm1.npy",
+ "url": "{url_prefix:0}testing_output/Identity_1_int8/0.npy"
+ },
+ {
+ "name": "ofm2.npy",
+ "url": "{url_prefix:0}testing_output/Identity_2_int8/0.npy"
+ },
+ {
+ "name": "ofm3.npy",
+ "url": "{url_prefix:0}testing_output/Identity_3_int8/0.npy"
+ },
+ {
+ "name": "ofm4.npy",
+ "url": "{url_prefix:0}testing_output/Identity_4_int8/0.npy"
+ }
+ ]
+ },
+ {
+ "name": "inference_runner",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/dnn_small/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "dnn_s_quantized.tflite",
+ "url": "{url_prefix:0}dnn_s_quantized.tflite"
+ }
+ ]
+ }
+]