From daba3cf2e3633cbd0e4f8aabe7578b97e88deee1 Mon Sep 17 00:00:00 2001 From: Alex Tawse Date: Fri, 29 Sep 2023 15:55:38 +0100 Subject: MLECO-3995: Pylint + Shellcheck compatibility * All Python scripts updated to abide by Pylint rules * good-names updated to permit short variable names: i, j, k, f, g, ex * ignore-long-lines regex updated to allow long lines for licence headers * Shell scripts now compliant with Shellcheck Signed-off-by: Alex Tawse Change-Id: I8d5af8279bc08bb8acfe8f6ee7df34965552bbe5 --- .pylintrc | 653 +++++++++++++++ build_default.py | 261 ++++-- download_dependencies.py | 96 ++- .../post_training_quantization.py | 61 +- .../quantization_aware_training.py | 68 +- model_conditioning_examples/setup.sh | 9 +- model_conditioning_examples/training_utils.py | 5 +- model_conditioning_examples/weight_clustering.py | 87 +- model_conditioning_examples/weight_pruning.py | 75 +- scripts/py/check_update_resources_downloaded.py | 54 +- scripts/py/dependency_urls.json | 8 + scripts/py/gen_audio.py | 107 ++- scripts/py/gen_audio_cpp.py | 258 ++++-- scripts/py/gen_default_input_cpp.py | 49 +- scripts/py/gen_labels_cpp.py | 74 +- scripts/py/gen_model_cpp.py | 89 +- scripts/py/gen_rgb_cpp.py | 203 +++-- scripts/py/gen_test_data_cpp.py | 317 ++++++-- scripts/py/gen_utils.py | 194 +++-- scripts/py/git_pre_push_hooks.sh | 48 ++ scripts/py/rnnoise_dump_extractor.py | 79 +- scripts/py/setup_hooks.py | 109 +-- scripts/py/templates/header_template.txt | 2 +- scripts/py/use_case_resources.json | 190 +++++ set_up_default_resources.py | 898 +++++++++++---------- 25 files changed, 2914 insertions(+), 1080 deletions(-) create mode 100644 .pylintrc create mode 100644 scripts/py/dependency_urls.json create mode 100755 scripts/py/git_pre_push_hooks.sh create mode 100644 scripts/py/use_case_resources.json diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..b6cb7ee --- /dev/null +++ b/.pylintrc @@ -0,0 +1,653 @@ +# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. 
+#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10.0 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. 
+#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + f, + g, + x, + y, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. 
+#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$|^# SPDX-FileCopyrightText.*$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. 
+max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero, + use-symbolic-message-instead + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable= + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 
'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are: text, parseable, colorized, +# json2 (improved json format), json (old json format) and msvs (visual +# studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. 
+check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. 
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/build_default.py b/build_default.py index 1d562f9..907bf4d 100755 --- a/build_default.py +++ b/build_default.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Script to build the ML Embedded Evaluation kit using default configuration +""" import logging import multiprocessing import os @@ -22,6 +25,7 @@ import sys import threading from argparse import ArgumentDefaultsHelpFormatter from argparse import ArgumentParser +from collections import namedtuple from pathlib import Path from set_up_default_resources import default_npu_config_names @@ -29,66 +33,178 @@ from set_up_default_resources import get_default_npu_config_from_name from set_up_default_resources import set_up_resources from set_up_default_resources import valid_npu_config_names +BuildArgs = namedtuple( + "BuildArgs", + [ + "toolchain", + "download_resources", + "run_vela_on_models", + "npu_config_name", + "make_jobs", + "make_verbose", + ], +) + class PipeLogging(threading.Thread): + """ + Class used to log stdout from subprocesses + """ + def __init__(self, log_level): threading.Thread.__init__(self) - self.logLevel = log_level - self.fileRead, self.fileWrite = os.pipe() - self.pipeIn = os.fdopen(self.fileRead) + self.log_level = log_level + self.file_read, self.file_write = os.pipe() + self.pipe_in = os.fdopen(self.file_read) self.daemon = False self.start() def fileno(self): - return self.fileWrite + """ + Get self.file_write + + Returns + ------- + self.file_write + """ + return self.file_write def run(self): - for line in iter(self.pipeIn.readline, ""): - logging.log(self.logLevel, line.strip("\n")) + """ + Log the contents of self.pipe_in + """ + for line in iter(self.pipe_in.readline, ""): + logging.log(self.log_level, line.strip("\n")) - self.pipeIn.close() + self.pipe_in.close() def close(self): - os.close(self.fileWrite) + """ + Close the pipe + """ + os.close(self.file_write) + + +def get_toolchain_file_name(toolchain: str) -> str: + """ + Get the name of the toolchain file for the toolchain. + + Parameters + ---------- + toolchain : name of the specified toolchain + + Returns + ------- + name of the toolchain file corresponding to the specified + toolchain + """ + if toolchain == "arm": + return "bare-metal-armclang.cmake" + + if toolchain == "gnu": + return "bare-metal-gcc.cmake" + + raise ValueError("Toolchain must be one of: gnu, arm") + + +def prep_build_dir( + current_file_dir: Path, + target_platform: str, + target_subsystem: str, + npu_config_name: str, + toolchain: str +) -> Path: + """ + Create or clean the build directory for this project. + + Parameters + ---------- + current_file_dir : The current directory of the running script + target_platform : The name of the target platform, e.g. "mps3" + target_subsystem: : The name of the target subsystem, e.g. "sse-300" + npu_config_name : The NPU config name, e.g. 
"ethos-u55-32" + toolchain : The name of the specified toolchain, e.g."arm" + + Returns + ------- + The path to the build directory + """ + build_dir = ( + current_file_dir / + f"cmake-build-{target_platform}-{target_subsystem}-{npu_config_name}-{toolchain}" + ) + + try: + build_dir.mkdir() + except FileExistsError: + # Directory already exists, clean it. + for filepath in build_dir.iterdir(): + try: + if filepath.is_file() or filepath.is_symlink(): + filepath.unlink() + elif filepath.is_dir(): + shutil.rmtree(filepath) + except OSError as err: + logging.error("Failed to delete %s. Reason: %s", filepath, err) + + return build_dir -def run( - toolchain: str, - download_resources: bool, - run_vela_on_models: bool, - npu_config_name: str, - make_jobs: int, - make_verbose: bool, +def run_command( + command: str, + logpipe: PipeLogging, + fail_message: str ): + """ + Run a command and exit upon failure. + + Parameters + ---------- + command : The command to run + logpipe : The PipeLogging object to capture stdout + fail_message : The message to log upon a non-zero exit code + """ + logging.info("\n\n\n%s\n\n\n", command) + + try: + subprocess.run( + command, check=True, shell=True, stdout=logpipe, stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as err: + logging.error(fail_message) + logpipe.close() + sys.exit(err.returncode) + + +def run(args: BuildArgs): """ Run the helpers scripts. Parameters: ---------- - toolchain (str) : Specifies if 'gnu' or 'arm' toolchain needs to be used. - download_resources (bool): Specifies if 'Download resources' step is performed. - run_vela_on_models (bool): Only if `download_resources` is True, specifies if run vela on downloaded models. - npu_config_name(str) : Ethos-U NPU configuration name. See "valid_npu_config_names" + args (BuildArgs) : Parsed set of build args expecting: + - toolchain + - download_resources + - run_vela_on_models + - np_config_name + toolchain (str) : Specifies if 'gnu' or 'arm' toolchain needs to be used. + download_resources (bool) : Specifies if 'Download resources' step is performed. + run_vela_on_models (bool) : Only if `download_resources` is True, specifies if + run vela on downloaded models. + npu_config_name(str) : Ethos-U NPU configuration name. See "valid_npu_config_names" """ current_file_dir = Path(__file__).parent.resolve() # 1. Make sure the toolchain is supported, and set the right one here - supported_toolchain_ids = ["gnu", "arm"] - assert ( - toolchain in supported_toolchain_ids - ), f"Toolchain must be from {supported_toolchain_ids}" - if toolchain == "arm": - toolchain_file_name = "bare-metal-armclang.cmake" - elif toolchain == "gnu": - toolchain_file_name = "bare-metal-gcc.cmake" + toolchain_file_name = get_toolchain_file_name(args.toolchain) # 2. 
Download models if specified - if download_resources is True: + if args.download_resources is True: logging.info("Downloading resources.") - (download_dir, env_path) = set_up_resources( - run_vela_on_models=run_vela_on_models, - additional_npu_config_names=[npu_config_name], + env_path = set_up_resources( + run_vela_on_models=args.run_vela_on_models, + additional_npu_config_names=(args.npu_config_name,), additional_requirements_file=current_file_dir / "scripts" / "py" / "requirements.txt" ) @@ -96,57 +212,42 @@ def run( logging.info("Building default configuration.") target_platform = "mps3" target_subsystem = "sse-300" - ethos_u_cfg = get_default_npu_config_from_name(npu_config_name) - build_dir = current_file_dir / f"cmake-build-{target_platform}-{target_subsystem}-{npu_config_name}-{toolchain}" - try: - build_dir.mkdir() - except FileExistsError: - # Directory already exists, clean it. - for filepath in build_dir.iterdir(): - try: - if filepath.is_file() or filepath.is_symlink(): - filepath.unlink() - elif filepath.is_dir(): - shutil.rmtree(filepath) - except Exception as e: - logging.error(f"Failed to delete {filepath}. Reason: {e}") + build_dir = prep_build_dir( + current_file_dir, + target_platform, + target_subsystem, + args.npu_config_name, + args.toolchain + ) logpipe = PipeLogging(logging.INFO) - cmake_toolchain_file = current_file_dir / "scripts" / "cmake" / "toolchains" / toolchain_file_name + cmake_toolchain_file = ( + current_file_dir / + "scripts" / + "cmake" / + "toolchains" / + toolchain_file_name + ) + ethos_u_cfg = get_default_npu_config_from_name(args.npu_config_name) cmake_path = env_path / "bin" / "cmake" cmake_command = ( f"{cmake_path} -B {build_dir} -DTARGET_PLATFORM={target_platform}" - + f" -DTARGET_SUBSYSTEM={target_subsystem}" - + f" -DCMAKE_TOOLCHAIN_FILE={cmake_toolchain_file}" - + f" -DETHOS_U_NPU_ID={ethos_u_cfg.ethos_u_npu_id}" - + f" -DETHOS_U_NPU_CONFIG_ID={ethos_u_cfg.ethos_u_config_id}" - + f" -DTENSORFLOW_LITE_MICRO_CLEAN_DOWNLOADS=ON" - ) - - logging.info(f"\n\n\n{cmake_command}\n\n\n") - state = subprocess.run( - cmake_command, shell=True, stdout=logpipe, stderr=subprocess.STDOUT + f" -DTARGET_SUBSYSTEM={target_subsystem}" + f" -DCMAKE_TOOLCHAIN_FILE={cmake_toolchain_file}" + f" -DETHOS_U_NPU_ID={ethos_u_cfg.ethos_u_npu_id}" + f" -DETHOS_U_NPU_CONFIG_ID={ethos_u_cfg.ethos_u_config_id}" + " -DTENSORFLOW_LITE_MICRO_CLEAN_DOWNLOADS=ON" ) - if state.returncode != 0: - logging.error("Failed to configure the project.") - logpipe.close() - sys.exit(state.returncode) + run_command(cmake_command, logpipe, fail_message="Failed to configure the project.") - make_command = f"{cmake_path} --build {build_dir} -j{make_jobs}" - if make_verbose: + make_command = f"{cmake_path} --build {build_dir} -j{args.make_jobs}" + if args.make_verbose: make_command += " --verbose" - logging.info(f"\n\n\n{make_command}\n\n\n") - state = subprocess.run( - make_command, shell=True, stdout=logpipe, stderr=subprocess.STDOUT - ) - if state.returncode != 0: - logging.error("Failed to build project.") - logpipe.close() - sys.exit(state.returncode) + run_command(make_command, logpipe, fail_message="Failed to build project.") logpipe.close() @@ -185,18 +286,20 @@ if __name__ == "__main__": parser.add_argument( "--make-verbose", help="Make runs with VERBOSE=1", action="store_true" ) - args = parser.parse_args() + parsed_args = parser.parse_args() logging.basicConfig( filename="log_build_default.log", level=logging.DEBUG, filemode="w" ) 
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - run( - args.toolchain.lower(), - not args.skip_download, - not args.skip_vela, - args.npu_config_name, - args.make_jobs, - args.make_verbose, + build_args = BuildArgs( + toolchain=parsed_args.toolchain.lower(), + download_resources=not parsed_args.skip_download, + run_vela_on_models=not parsed_args.skip_vela, + npu_config_name=parsed_args.npu_config_name, + make_jobs=parsed_args.make_jobs, + make_verbose=parsed_args.make_verbose ) + + run(build_args) diff --git a/download_dependencies.py b/download_dependencies.py index 3934f94..786fa4c 100755 --- a/download_dependencies.py +++ b/download_dependencies.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 - -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,69 +15,92 @@ # limitations under the License. """This script does effectively the same as "git submodule update --init" command.""" +import json import logging import sys import tarfile import tempfile +import typing +from pathlib import Path from urllib.request import urlopen from zipfile import ZipFile -from pathlib import Path -TF = "https://github.com/tensorflow/tflite-micro/archive/568d181ccc1f60e49742fd43b7f97141ee8d45fc.zip" -CMSIS = "https://github.com/ARM-software/CMSIS_5/archive/a75f01746df18bb5b929dfb8dc6c9407fac3a0f3.zip" -CMSIS_DSP = "https://github.com/ARM-software/CMSIS-DSP/archive/refs/tags/v1.15.0.zip" -CMSIS_NN = "https://github.com/ARM-software/CMSIS-NN/archive/refs/85164a811917770d7027a12a57ed3b469dac6537.zip" -ETHOS_U_CORE_DRIVER = "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/snapshot/ethos-u-core-driver-23.08.tar.gz" -ETHOS_U_CORE_PLATFORM = "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-platform.git/snapshot/ethos-u-core-platform-23.08.tar.gz" +def download( + url_file: str, + to_path: Path, +): + """ + Download a file from the specified URL -def download(url_file: str, post_process=None): + @param url_file: The URL of the file to download + @param to_path: The location to download the file to + """ with urlopen(url_file) as response, tempfile.NamedTemporaryFile() as temp: - logging.info(f"Downloading {url_file} ...") + logging.info("Downloading %s ...", url_file) temp.write(response.read()) temp.seek(0) - logging.info(f"Finished downloading {url_file}.") - if post_process: - post_process(temp) - - -def unzip(file, to_path): - with ZipFile(file) as z: - for archive_path in z.infolist(): + logging.info("Finished downloading %s.", url_file) + if url_file.endswith(".tar.gz"): + untar(temp, to_path) + else: + unzip(temp, to_path) + + +def unzip( + file: typing.IO[bytes], + to_path: Path +): + """ + Unzip the specified file + + @param file: The file to unzip + @param to_path: The location to extract to + """ + with ZipFile(file) as f: + for archive_path in f.infolist(): archive_path.filename = archive_path.filename[archive_path.filename.find("/") + 1:] if archive_path.filename: - z.extract(archive_path, to_path) + f.extract(archive_path, to_path) target_path = to_path / archive_path.filename attr = archive_path.external_attr >> 16 if attr != 0: target_path.chmod(attr) -def untar(file, to_path): - with tarfile.open(file) as z: - for archive_path in z.getmembers(): +def untar( + file: bytes, + to_path: Path +): + """ + Untar the specified file + + @param file: The file to 
untar + @param to_path: The location to extract to + """ + with tarfile.open(file) as f: + for archive_path in f.getmembers(): index = archive_path.name.find("/") if index < 0: continue archive_path.name = archive_path.name[index + 1:] if archive_path.name: - z.extract(archive_path, to_path) + f.extract(archive_path, to_path) def main(dependencies_path: Path): + """ + Download all dependencies + + @param dependencies_path: The path to which the dependencies will be downloaded + """ + dependency_urls_path = ( + Path(__file__).parent.resolve() / "scripts" / "py" / "dependency_urls.json") + with open(dependency_urls_path, encoding="utf8") as f: + dependency_urls = json.load(f) - download(CMSIS, - lambda file: unzip(file.name, to_path=dependencies_path / "cmsis")) - download(CMSIS_DSP, - lambda file: unzip(file.name, to_path=dependencies_path / "cmsis-dsp")) - download(CMSIS_NN, - lambda file: unzip(file.name, to_path=dependencies_path / "cmsis-nn")) - download(ETHOS_U_CORE_DRIVER, - lambda file: untar(file.name, to_path=dependencies_path / "core-driver")) - download(ETHOS_U_CORE_PLATFORM, - lambda file: untar(file.name, to_path=dependencies_path / "core-platform")) - download(TF, - lambda file: unzip(file.name, to_path=dependencies_path / "tensorflow")) + for name, url in dependency_urls.items(): + download(url, dependencies_path / name) if __name__ == '__main__': @@ -88,6 +110,6 @@ if __name__ == '__main__': download_dir = Path(__file__).parent.resolve() / "dependencies" if download_dir.is_dir(): - logging.info(f'{download_dir} exists. Skipping download.') + logging.info('%s exists. Skipping download.', download_dir) else: main(download_dir) diff --git a/model_conditioning_examples/post_training_quantization.py b/model_conditioning_examples/post_training_quantization.py index a39be0e..42069f5 100644 --- a/model_conditioning_examples/post_training_quantization.py +++ b/model_conditioning_examples/post_training_quantization.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,28 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This script will provide you with an example of how to perform post-training quantization in TensorFlow. +This script will provide you with an example of how to perform +post-training quantization in TensorFlow. -The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit -integer values. +The output from this example will be a TensorFlow Lite model file +where weights and activations are quantized to 8bit integer values. -Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm -Ethos NPU. +Quantization helps reduce the size of your models and is necessary +for running models on certain hardware such as Arm Ethos NPU. -In addition to quantizing weights, post-training quantization uses a calibration dataset to -capture the minimum and maximum values of all variable tensors in your model. -By capturing these ranges it is possible to fully quantize not just the weights of the model but also the activations. 
+In addition to quantizing weights, post-training quantization uses +a calibration dataset to capture the minimum and maximum values of +all variable tensors in your model. By capturing these ranges it +is possible to fully quantize not just the weights of the model +but also the activations. -Depending on the model you are quantizing there may be some accuracy loss, but for a lot of models the loss should -be minimal. +Depending on the model you are quantizing there may be some accuracy loss, +but for a lot of models the loss should be minimal. -If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela +If you are targeting an Arm Ethos-U55 NPU then the output +TensorFlow Lite file will also need to be passed through the Vela compiler for further optimizations before it can be used. -For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ -For more information on post-training quantization -see: https://www.tensorflow.org/lite/performance/post_training_integer_quant +For more information on using Vela see: + https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ +For more information on post-training quantization see: + https://www.tensorflow.org/lite/performance/post_training_integer_quant """ + import pathlib import numpy as np @@ -44,7 +50,8 @@ from training_utils import get_data, create_model def post_training_quantize(keras_model, sample_data): - """Quantize Keras model using post-training quantization with some sample data. + """ + Quantize Keras model using post-training quantization with some sample data. TensorFlow Lite will have fp32 inputs/outputs and the model will handle quantizing/dequantizing. @@ -76,8 +83,14 @@ def post_training_quantize(keras_model, sample_data): return tflite_model -def evaluate_tflite_model(tflite_save_path, x_test, y_test): - """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter. +# pylint: disable=duplicate-code +def evaluate_tflite_model( + tflite_save_path: pathlib.Path, + x_test: np.ndarray, + y_test: np.ndarray +): + """ + Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter. Args: tflite_save_path: Path to TensorFlow Lite model to test. @@ -106,6 +119,9 @@ def evaluate_tflite_model(tflite_save_path, x_test, y_test): def main(): + """ + Run post-training quantization + """ x_train, y_train, x_test, y_test = get_data() model = create_model() @@ -117,7 +133,7 @@ def main(): model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True) # Test the fp32 model accuracy. - test_loss, test_acc = model.evaluate(x_test, y_test) + test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable print(f"Test accuracy float: {test_acc:.3f}") # Quantize and export the resulting TensorFlow Lite model to file. @@ -132,7 +148,12 @@ def main(): # Test the quantized model accuracy. Save time by only testing a subset of the whole data. 
num_test_samples = 1000 - evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples]) + evaluate_tflite_model( + quant_model_save_path, + x_test[0:num_test_samples], + y_test[0:num_test_samples] + ) +# pylint: enable=duplicate-code if __name__ == "__main__": diff --git a/model_conditioning_examples/quantization_aware_training.py b/model_conditioning_examples/quantization_aware_training.py index 3d492a7..d590763 100644 --- a/model_conditioning_examples/quantization_aware_training.py +++ b/model_conditioning_examples/quantization_aware_training.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,31 +13,38 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This script will provide you with a short example of how to perform quantization aware training in TensorFlow using the +This script will provide you with a short example of how to perform +quantization aware training in TensorFlow using the TensorFlow Model Optimization Toolkit. -The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit -integer values. +The output from this example will be a TensorFlow Lite model file +where weights and activations are quantized to 8bit integer values. -Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm -Ethos NPU. +Quantization helps reduce the size of your models and is necessary +for running models on certain hardware such as Arm Ethos NPU. -In quantization aware training (QAT), the error introduced with quantizing from fp32 to int8 is simulated using -fake quantization nodes. By simulating this quantization error when training, the model can learn better adapted -weights and minimize accuracy losses caused by the reduced precision. +In quantization aware training (QAT), the error introduced with +quantizing from fp32 to int8 is simulated using fake quantization nodes. +By simulating this quantization error when training, +the model can learn better adapted weights and minimize accuracy losses +caused by the reduced precision. -Minimum and maximum values for activations are also captured during training so activations for every layer can be -quantized along with the weights later. +Minimum and maximum values for activations are also captured +during training so activations for every layer can be quantized +along with the weights later. -Quantization is only simulated during training and the training backward passes are still performed in full float -precision. Actual quantization happens when generating a TensorFlow Lite model. +Quantization is only simulated during training and the +training backward passes are still performed in full float precision. +Actual quantization happens when generating a TensorFlow Lite model. -If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela +If you are targeting an Arm Ethos-U55 NPU then the output +TensorFlow Lite file will also need to be passed through the Vela compiler for further optimizations before it can be used. 
-For more information on using vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ -For more information on quantization aware training -see: https://www.tensorflow.org/model_optimization/guide/quantization/training +For more information on using vela see: + https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ +For more information on quantization aware training see: + https://www.tensorflow.org/model_optimization/guide/quantization/training """ import pathlib @@ -64,13 +71,15 @@ def quantize_and_convert_to_tflite(keras_model): # After doing quantization aware training all the information for creating a fully quantized # TensorFlow Lite model is already within the quantization aware Keras model. - # This means we only need to call convert with default optimizations to generate the quantized TensorFlow Lite model. + # This means we only need to call convert with default optimizations to + # generate the quantized TensorFlow Lite model. converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() return tflite_model +# pylint: disable=duplicate-code def evaluate_tflite_model(tflite_save_path, x_test, y_test): """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter. @@ -101,13 +110,19 @@ def evaluate_tflite_model(tflite_save_path, x_test, y_test): def main(): + """ + Run quantization aware training + """ x_train, y_train, x_test, y_test = get_data() model = create_model() - # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit we can make our - # model quantization aware in one line. Once this is done we compile the model and train as normal. - # It is important to note that the model is only quantization aware and is not quantized yet. The weights are - # still floating point and will only be converted to int8 when we generate the TensorFlow Lite model later on. + # When working with the TensorFlow Keras API and theTF Model Optimization Toolkit + # we can make our model quantization aware in one line. + # Once this is done we compile the model and train as normal. + # It is important to note that the model is only quantization aware + # and is not quantized yet. + # The weights are still floating point and will only be converted + # to int8 when we generate the TensorFlow Lite model later on. quant_aware_model = tfmot.quantization.keras.quantize_model(model) quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), @@ -117,7 +132,7 @@ def main(): quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True) # Test the quantization aware model accuracy. - test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test) + test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test) # pylint: disable=unused-variable print(f"Test accuracy quant aware: {test_acc:.3f}") # Quantize and save the resulting TensorFlow Lite model to file. @@ -132,7 +147,12 @@ def main(): # Test quantized model accuracy. Save time by only testing a subset of the whole data. 
num_test_samples = 1000 - evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples]) + evaluate_tflite_model( + quant_model_save_path, + x_test[0:num_test_samples], + y_test[0:num_test_samples] + ) +# pylint: enable=duplicate-code if __name__ == "__main__": diff --git a/model_conditioning_examples/setup.sh b/model_conditioning_examples/setup.sh index 92de78a..678f9d3 100644 --- a/model_conditioning_examples/setup.sh +++ b/model_conditioning_examples/setup.sh @@ -1,5 +1,7 @@ +#!/bin/bash + #---------------------------------------------------------------------------- -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +16,9 @@ # See the License for the specific language governing permissions and # limitations under the License. #---------------------------------------------------------------------------- -#!/bin/bash + python3 -m venv ./env +# shellcheck disable=SC1091 source ./env/bin/activate pip install -U pip -pip install -r requirements.txt \ No newline at end of file +pip install -r requirements.txt diff --git a/model_conditioning_examples/training_utils.py b/model_conditioning_examples/training_utils.py index a022bd1..2ce94b8 100644 --- a/model_conditioning_examples/training_utils.py +++ b/model_conditioning_examples/training_utils.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -49,7 +49,8 @@ def create_model(): """ keras_model = tf.keras.models.Sequential([ - tf.keras.layers.Conv2D(32, 3, padding='same', input_shape=(28, 28, 1), activation=tf.nn.relu), + tf.keras.layers.Conv2D(32, 3, padding='same', + input_shape=(28, 28, 1), activation=tf.nn.relu), tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu), tf.keras.layers.MaxPool2D(), tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu), diff --git a/model_conditioning_examples/weight_clustering.py b/model_conditioning_examples/weight_clustering.py index 6672d53..e966336 100644 --- a/model_conditioning_examples/weight_clustering.py +++ b/model_conditioning_examples/weight_clustering.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,22 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This script will provide you with a short example of how to perform clustering of weights (weight sharing) in -TensorFlow using the TensorFlow Model Optimization Toolkit. +This script will provide you with a short example of how to perform +clustering of weights (weight sharing) in TensorFlow +using the TensorFlow Model Optimization Toolkit. -The output from this example will be a TensorFlow Lite model file where weights in each layer have been 'clustered' into -16 clusters during training - quantization has then been applied on top of this. 
+The output from this example will be a TensorFlow Lite model file +where weights in each layer have been 'clustered' into 16 clusters +during training - quantization has then been applied on top of this. -By clustering the model we can improve compression of the model file. This can be essential for deploying certain -models on systems with limited resources - such as embedded systems using an Arm Ethos NPU. +By clustering the model we can improve compression of the model file. +This can be essential for deploying certain models on systems with +limited resources - such as embedded systems using an Arm Ethos NPU. -After performing clustering we do post-training quantization to quantize the model and then generate a TensorFlow Lite file. +After performing clustering we do post-training quantization +to quantize the model and then generate a TensorFlow Lite file. -If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela +If you are targeting an Arm Ethos-U55 NPU then the output +TensorFlow Lite file will also need to be passed through the Vela compiler for further optimizations before it can be used. -For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ -For more information on clustering see: https://www.tensorflow.org/model_optimization/guide/clustering +For more information on using Vela see: + https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ +For more information on clustering see: + https://www.tensorflow.org/model_optimization/guide/clustering """ import pathlib @@ -42,39 +49,52 @@ from post_training_quantization import post_training_quantize, evaluate_tflite_m def prepare_for_clustering(keras_model): """Prepares a Keras model for clustering.""" - # Choose the number of clusters to use and how to initialize them. Using more clusters will generally - # reduce accuracy so you will need to find the optimal number for your use-case. + # Choose the number of clusters to use and how to initialize them. + # Using more clusters will generally reduce accuracy, + # so you will need to find the optimal number for your use-case. number_of_clusters = 16 cluster_centroids_init = tfmot.clustering.keras.CentroidInitialization.LINEAR - # Apply the clustering wrapper to the whole model so weights in every layer will get clustered. You may find that - # to avoid too much accuracy loss only certain non-critical layers in your model should be clustered. - clustering_ready_model = tfmot.clustering.keras.cluster_weights(keras_model, - number_of_clusters=number_of_clusters, - cluster_centroids_init=cluster_centroids_init) + # Apply the clustering wrapper to the whole model so weights in + # every layer will get clustered. You may find that to avoid + # too much accuracy loss only certain non-critical layers in + # your model should be clustered. + clustering_ready_model = tfmot.clustering.keras.cluster_weights( + keras_model, + number_of_clusters=number_of_clusters, + cluster_centroids_init=cluster_centroids_init + ) # We must recompile the model after making it ready for clustering. 
- clustering_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), - loss=tf.keras.losses.sparse_categorical_crossentropy, - metrics=['accuracy']) + clustering_ready_model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + loss=tf.keras.losses.sparse_categorical_crossentropy, + metrics=['accuracy'] + ) return clustering_ready_model def main(): + """ + Run weight clustering + """ x_train, y_train, x_test, y_test = get_data() model = create_model() # Compile and train the model first. - # In general it is easier to do clustering as a fine-tuning step after the model is fully trained. - model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), - loss=tf.keras.losses.sparse_categorical_crossentropy, - metrics=['accuracy']) + # In general, it is easier to do clustering as a + # fine-tuning step after the model is fully trained. + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + loss=tf.keras.losses.sparse_categorical_crossentropy, + metrics=['accuracy'] + ) model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True) # Test the trained model accuracy. - test_loss, test_acc = model.evaluate(x_test, y_test) + test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable print(f"Test accuracy before clustering: {test_acc:.3f}") # Prepare the model for clustering. @@ -88,19 +108,26 @@ def main(): # Remove all variables that clustering only needed in the training phase. model_for_export = tfmot.clustering.keras.strip_clustering(clustered_model) - # Apply post-training quantization on top of the clustering and save the resulting TensorFlow Lite model to file. + # Apply post-training quantization on top of the clustering + # and save the resulting TensorFlow Lite model to file. tflite_model = post_training_quantize(model_for_export, x_train) tflite_models_dir = pathlib.Path('./conditioned_models/') tflite_models_dir.mkdir(exist_ok=True, parents=True) - clustered_quant_model_save_path = tflite_models_dir / 'clustered_post_training_quant_model.tflite' + clustered_quant_model_save_path = \ + tflite_models_dir / 'clustered_post_training_quant_model.tflite' with open(clustered_quant_model_save_path, 'wb') as f: f.write(tflite_model) - # Test the clustered quantized model accuracy. Save time by only testing a subset of the whole data. + # Test the clustered quantized model accuracy. + # Save time by only testing a subset of the whole data. num_test_samples = 1000 - evaluate_tflite_model(clustered_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples]) + evaluate_tflite_model( + clustered_quant_model_save_path, + x_test[0:num_test_samples], + y_test[0:num_test_samples] + ) if __name__ == "__main__": diff --git a/model_conditioning_examples/weight_pruning.py b/model_conditioning_examples/weight_pruning.py index cbf9cf9..303b6df 100644 --- a/model_conditioning_examples/weight_pruning.py +++ b/model_conditioning_examples/weight_pruning.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,23 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. 
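# For reference, a minimal sketch of the clustering fine-tuning flow implemented in
# weight_clustering.py above, assuming `model`, `x_train` and `y_train` are already
# defined (e.g. via create_model()/get_data()). Only calls that appear in this patch
# are used.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

def cluster_and_strip(model, x_train, y_train, number_of_clusters=16):
    """Cluster the weights, fine-tune briefly, then strip the clustering wrappers."""
    clustered = tfmot.clustering.keras.cluster_weights(
        model,
        number_of_clusters=number_of_clusters,
        cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.LINEAR
    )
    # Recompile after wrapping, as done in prepare_for_clustering() above.
    clustered.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                      loss=tf.keras.losses.sparse_categorical_crossentropy,
                      metrics=['accuracy'])
    clustered.fit(x=x_train, y=y_train, batch_size=128, epochs=1, shuffle=True)
    # Remove the variables that clustering only needed during training.
    return tfmot.clustering.keras.strip_clustering(clustered)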
""" -This script will provide you with a short example of how to perform magnitude-based weight pruning in TensorFlow +This script will provide you with a short example of how to perform +magnitude-based weight pruning in TensorFlow using the TensorFlow Model Optimization Toolkit. -The output from this example will be a TensorFlow Lite model file where ~75% percent of the weights have been 'pruned' to the +The output from this example will be a TensorFlow Lite model file +where ~75% percent of the weights have been 'pruned' to the value 0 during training - quantization has then been applied on top of this. -By pruning the model we can improve compression of the model file. This can be essential for deploying certain models -on systems with limited resources - such as embedded systems using Arm Ethos NPU. Also, if the pruned model is run -on an Arm Ethos NPU then this pruning can improve the execution time of the model. +By pruning the model we can improve compression of the model file. +This can be essential for deploying certain models on systems +with limited resources - such as embedded systems using Arm Ethos NPU. +Also, if the pruned model is run on an Arm Ethos NPU then +this pruning can improve the execution time of the model. -After pruning is complete we do post-training quantization to quantize the model and then generate a TensorFlow Lite file. +After pruning is complete we do post-training quantization +to quantize the model and then generate a TensorFlow Lite file. -If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela +If you are targeting an Arm Ethos-U55 NPU then the output +TensorFlow Lite file will also need to be passed through the Vela compiler for further optimizations before it can be used. -For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ -For more information on weight pruning see: https://www.tensorflow.org/model_optimization/guide/pruning +For more information on using Vela see: + https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ +For more information on weight pruning see: + https://www.tensorflow.org/model_optimization/guide/pruning """ import pathlib @@ -43,13 +51,20 @@ from post_training_quantization import post_training_quantize, evaluate_tflite_m def prepare_for_pruning(keras_model): """Prepares a Keras model for pruning.""" - # We use a constant sparsity schedule so the amount of sparsity in the model is kept at the same percent throughout - # training. An alternative is PolynomialDecay where sparsity can be gradually increased during training. + # We use a constant sparsity schedule so the amount of sparsity + # in the model is kept at the same percent throughout training. + # An alternative is PolynomialDecay where sparsity + # can be gradually increased during training. pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.75, begin_step=0) - # Apply the pruning wrapper to the whole model so weights in every layer will get pruned. You may find that to avoid - # too much accuracy loss only certain non-critical layers in your model should be pruned. - pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(keras_model, pruning_schedule=pruning_schedule) + # Apply the pruning wrapper to the whole model + # so weights in every layer will get pruned. + # You may find that to avoid too much accuracy loss only + # certain non-critical layers in your model should be pruned. 
+ pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude( + keras_model, + pruning_schedule=pruning_schedule + ) # We must recompile the model after making it ready for pruning. pruning_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), @@ -60,11 +75,15 @@ def prepare_for_pruning(keras_model): def main(): + """ + Run weight pruning + """ x_train, y_train, x_test, y_test = get_data() model = create_model() # Compile and train the model first. - # In general it is easier to do pruning as a fine-tuning step after the model is fully trained. + # In general, it is easier to do pruning as a fine-tuning step + # after the model is fully trained. model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=tf.keras.losses.sparse_categorical_crossentropy, metrics=['accuracy']) @@ -72,7 +91,7 @@ def main(): model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True) # Test the trained model accuracy. - test_loss, test_acc = model.evaluate(x_test, y_test) + test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable print(f"Test accuracy before pruning: {test_acc:.3f}") # Prepare the model for pruning and add the pruning update callback needed in training. @@ -80,14 +99,23 @@ def main(): callbacks = [tfmot.sparsity.keras.UpdatePruningStep()] # Continue training the model but now with pruning applied - remember to pass in the callbacks! - pruned_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True, callbacks=callbacks) + pruned_model.fit( + x=x_train, + y=y_train, + batch_size=128, + epochs=1, + verbose=1, + shuffle=True, + callbacks=callbacks + ) test_loss, test_acc = pruned_model.evaluate(x_test, y_test) print(f"Test accuracy after pruning: {test_acc:.3f}") # Remove all variables that pruning only needed in the training phase. model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model) - # Apply post-training quantization on top of the pruning and save the resulting TensorFlow Lite model to file. + # Apply post-training quantization on top of the pruning + # and save the resulting TensorFlow Lite model to file. tflite_model = post_training_quantize(model_for_export, x_train) tflite_models_dir = pathlib.Path('./conditioned_models/') @@ -97,9 +125,14 @@ def main(): with open(pruned_quant_model_save_path, 'wb') as f: f.write(tflite_model) - # Test the pruned quantized model accuracy. Save time by only testing a subset of the whole data. + # Test the pruned quantized model accuracy. + # Save time by only testing a subset of the whole data. num_test_samples = 1000 - evaluate_tflite_model(pruned_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples]) + evaluate_tflite_model( + pruned_quant_model_save_path, + x_test[0:num_test_samples], + y_test[0:num_test_samples] + ) if __name__ == "__main__": diff --git a/scripts/py/check_update_resources_downloaded.py b/scripts/py/check_update_resources_downloaded.py index 6e4da21..bdd9d62 100644 --- a/scripts/py/check_update_resources_downloaded.py +++ b/scripts/py/check_update_resources_downloaded.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,15 +13,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
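# For reference, a minimal sketch of the pruning fine-tuning flow implemented in
# weight_pruning.py above, assuming `model`, `x_train` and `y_train` already exist.
# Only calls that appear in this patch are used.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

def prune_and_strip(model, x_train, y_train, target_sparsity=0.75):
    """Prune to a constant sparsity, fine-tune for one epoch, then strip the wrappers."""
    schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=target_sparsity,
                                                     begin_step=0)
    pruned = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=schedule)
    pruned.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                   loss=tf.keras.losses.sparse_categorical_crossentropy,
                   metrics=['accuracy'])
    # The pruning step callback must be passed to fit() while pruning is applied.
    pruned.fit(x=x_train, y=y_train, batch_size=128, epochs=1, shuffle=True,
               callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
    return tfmot.sparsity.keras.strip_pruning(pruned)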
# See the License for the specific language governing permissions and # limitations under the License. - +""" +Contains methods to check if the downloaded resources need to be refreshed +""" +import hashlib import json import sys -import hashlib +import typing from argparse import ArgumentParser from pathlib import Path -def get_md5sum_for_file(filepath: str) -> str: +def get_md5sum_for_file( + filepath: typing.Union[str, Path] +) -> str: """ Function to calculate md5sum for contents of a given file. @@ -41,7 +46,7 @@ def get_md5sum_for_file(filepath: str) -> str: def check_update_resources_downloaded( - resource_downloaded_dir: str, set_up_script_path: str + resource_downloaded_dir: str, set_up_script_path: str ): """ Function that check if the resources downloaded need to be refreshed. @@ -55,27 +60,27 @@ def check_update_resources_downloaded( metadata_file_path = Path(resource_downloaded_dir) / "resources_downloaded_metadata.json" if metadata_file_path.is_file(): - with open(metadata_file_path) as metadata_json: - + with open(metadata_file_path, encoding="utf8") as metadata_json: metadata_dict = json.load(metadata_json) - md5_key = 'set_up_script_md5sum' - set_up_script_md5sum_metadata = '' - if md5_key in metadata_dict.keys(): - set_up_script_md5sum_metadata = metadata_dict["set_up_script_md5sum"] + md5_key = 'set_up_script_md5sum' + set_up_script_md5sum_metadata = '' + + if md5_key in metadata_dict.keys(): + set_up_script_md5sum_metadata = metadata_dict["set_up_script_md5sum"] - set_up_script_md5sum_current = get_md5sum_for_file(set_up_script_path) + set_up_script_md5sum_current = get_md5sum_for_file(set_up_script_path) - if set_up_script_md5sum_current == set_up_script_md5sum_metadata: - return 0 + if set_up_script_md5sum_current == set_up_script_md5sum_metadata: + return 0 - # Return code 1 if the resources need to be refreshed. - print('Error: hash mismatch!') - print(f'Metadata: {set_up_script_md5sum_metadata}') - print(f'Current : {set_up_script_md5sum_current}') - return 1 + # Return code 1 if the resources need to be refreshed. + print('Error: hash mismatch!') + print(f'Metadata: {set_up_script_md5sum_metadata}') + print(f'Current : {set_up_script_md5sum_current}') + return 1 - # Return error code 2 if the file doesn't exists. + # Return error code 2 if the file doesn't exist. 
print(f'Error: could not find {metadata_file_path}') return 2 @@ -99,7 +104,8 @@ if __name__ == "__main__": raise ValueError(f'Invalid script path: {args.setup_script_path}') # Check the resources are downloaded as expected - status = check_update_resources_downloaded( - args.resource_downloaded_dir, - args.setup_script_path) - sys.exit(status) + STATUS = check_update_resources_downloaded( + args.resource_downloaded_dir, + args.setup_script_path + ) + sys.exit(STATUS) diff --git a/scripts/py/dependency_urls.json b/scripts/py/dependency_urls.json new file mode 100644 index 0000000..33a84f7 --- /dev/null +++ b/scripts/py/dependency_urls.json @@ -0,0 +1,8 @@ +{ + "cmsis": "https://github.com/ARM-software/CMSIS_5/archive/a75f01746df18bb5b929dfb8dc6c9407fac3a0f3.zip", + "cmsis-dsp": "https://github.com/ARM-software/CMSIS-DSP/archive/refs/tags/v1.15.0.zip", + "cmsis-nn": "https://github.com/ARM-software/CMSIS-NN/archive/refs/85164a811917770d7027a12a57ed3b469dac6537.zip", + "core-driver": "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/snapshot/ethos-u-core-driver-23.08.tar.gz", + "core-platform": "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-platform.git/snapshot/ethos-u-core-platform-23.08.tar.gz", + "tensorflow": "https://github.com/tensorflow/tflite-micro/archive/568d181ccc1f60e49742fd43b7f97141ee8d45fc.zip" +} diff --git a/scripts/py/gen_audio.py b/scripts/py/gen_audio.py index ff33bfb..4d7318c 100644 --- a/scripts/py/gen_audio.py +++ b/scripts/py/gen_audio.py @@ -1,6 +1,6 @@ #!env/bin/python3 -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,34 +17,99 @@ """ Utility script to convert an audio clip into eval platform desired spec. 
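# For reference: the refresh check above reduces to an md5 comparison like the hedged
# sketch below (helper names are illustrative; the real script also distinguishes a
# stale hash from a missing metadata file via different exit codes).
import hashlib
import json
from pathlib import Path

def md5_of_file(path: Path) -> str:
    """Return the md5 hex digest of a file's contents."""
    return hashlib.md5(path.read_bytes()).hexdigest()

def resources_need_refresh(metadata_json: Path, set_up_script: Path) -> bool:
    """True if the stored set_up_script_md5sum differs from the current hash."""
    metadata = json.loads(metadata_json.read_text(encoding="utf8"))
    return metadata.get("set_up_script_md5sum") != md5_of_file(set_up_script)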
""" -import soundfile as sf - from argparse import ArgumentParser from os import path -from gen_utils import AudioUtils +import soundfile as sf + +from gen_utils import GenUtils parser = ArgumentParser() -parser.add_argument("--audio_path", help="Audio file path", required=True) -parser.add_argument("--output_dir", help="Output directory", required=True) -parser.add_argument("--sampling_rate", type=int, help="target sampling rate.", default=16000) -parser.add_argument("--mono", type=bool, help="convert signal to mono.", default=True) -parser.add_argument("--offset", type=float, help="start reading after this time (in seconds).", default=0) -parser.add_argument("--duration", type=float, help="only load up to this much audio (in seconds).", default=0) -parser.add_argument("--res_type", type=AudioUtils.res_data_type, help=f"Resample type: {AudioUtils.res_type_list()}.", default='kaiser_best') -parser.add_argument("--min_samples", type=int, help="Minimum sample number.", default=16000) -parser.add_argument("-v", "--verbosity", action="store_true") -args = parser.parse_args() + +# pylint: disable=duplicate-code +parser.add_argument( + "--audio_path", + help="Audio file path", + required=True +) + +parser.add_argument( + "--output_dir", + help="Output directory", + required=True +) + +parser.add_argument( + "--sampling_rate", + type=int, + help="target sampling rate.", + default=16000 +) + +parser.add_argument( + "--mono", + type=bool, + help="convert signal to mono.", + default=True +) + +parser.add_argument( + "--offset", + type=float, + help="start reading after this time (in seconds).", + default=0 +) + +parser.add_argument( + "--duration", + type=float, + help="only load up to this much audio (in seconds).", + default=0 +) + +parser.add_argument( + "--res_type", + type=GenUtils.res_data_type, + help=f"Resample type: {GenUtils.res_type_list()}.", + default='kaiser_best' +) + +parser.add_argument( + "--min_samples", + type=int, + help="Minimum sample number.", + default=16000 +) + +parser.add_argument( + "-v", + "--verbosity", + action="store_true" +) +# pylint: enable=duplicate-code + +parsed_args = parser.parse_args() def main(args): - audio_data, samplerate = AudioUtils.load_resample_audio_clip(args.audio_path, - args.sampling_rate, - args.mono, args.offset, - args.duration, args.res_type, - args.min_samples) - sf.write(path.join(args.output_dir, path.basename(args.audio_path)), audio_data, samplerate) + """ + Generate the new audio file + @param args: Parsed args + """ + audio_sample = GenUtils.read_audio_file( + args.audio_path, args.offset, args.duration + ) + + resampled_audio = GenUtils.resample_audio_clip( + audio_sample, args.sampling_rate, args.mono, args.res_type, args.min_samples + ) + + sf.write( + path.join(args.output_dir, path.basename(args.audio_path)), + resampled_audio.data, + resampled_audio.sample_rate + ) if __name__ == '__main__': - main(args) + main(parsed_args) diff --git a/scripts/py/gen_audio_cpp.py b/scripts/py/gen_audio_cpp.py index 850a871..89d9ae1 100644 --- a/scripts/py/gen_audio_cpp.py +++ b/scripts/py/gen_audio_cpp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,93 +21,217 @@ from the cpp files. 
import datetime import glob import math -from pathlib import Path from argparse import ArgumentParser +from pathlib import Path import numpy as np from jinja2 import Environment, FileSystemLoader -from gen_utils import AudioUtils +from gen_utils import GenUtils, AudioSample +# pylint: disable=duplicate-code parser = ArgumentParser() -parser.add_argument("--audio_path", type=str, help="path to audio folder to convert.") -parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.") -parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.") -parser.add_argument("--sampling_rate", type=int, help="target sampling rate.", default=16000) -parser.add_argument("--mono", type=bool, help="convert signal to mono.", default=True) -parser.add_argument("--offset", type=float, help="start reading after this time (in seconds).", default=0) -parser.add_argument("--duration", type=float, help="only load up to this much audio (in seconds).", default=0) -parser.add_argument("--res_type", type=AudioUtils.res_data_type, help=f"Resample type: {AudioUtils.res_type_list()}.", - default='kaiser_best') -parser.add_argument("--min_samples", type=int, help="Minimum sample number.", default=16000) -parser.add_argument("--license_template", type=str, help="Header template file", - default="header_template.txt") -parser.add_argument("-v", "--verbosity", action="store_true") -args = parser.parse_args() + +parser.add_argument( + "--audio_path", + type=str, + help="path to audio folder to convert." +) + +parser.add_argument( + "--source_folder_path", + type=str, + help="path to source folder to be generated." +) + +parser.add_argument( + "--header_folder_path", + type=str, + help="path to header folder to be generated." 
+) + +parser.add_argument( + "--sampling_rate", + type=int, + help="target sampling rate.", + default=16000 +) + +parser.add_argument( + "--mono", + type=bool, + help="convert signal to mono.", + default=True +) + +parser.add_argument( + "--offset", + type=float, + help="start reading after this time (in seconds).", + default=0 +) + +parser.add_argument( + "--duration", + type=float, + help="only load up to this much audio (in seconds).", + default=0 +) + +parser.add_argument( + "--res_type", + type=GenUtils.res_data_type, + help=f"Resample type: {GenUtils.res_type_list()}.", + default='kaiser_best' +) + +parser.add_argument( + "--min_samples", + type=int, + help="Minimum sample number.", + default=16000 +) + +parser.add_argument( + "--license_template", + type=str, + help="Header template file", + default="header_template.txt" +) + +parser.add_argument( + "-v", + "--verbosity", + action="store_true" +) + +parsed_args = parser.parse_args() env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'), trim_blocks=True, lstrip_blocks=True) -def write_hpp_file(header_filepath, cc_filepath, header_template_file, num_audios, audio_filenames, audio_array_namesizes): +# pylint: enable=duplicate-code +def write_hpp_file( + header_filepath, + header, + num_audios, + audio_array_namesizes +): + """ + Write audio hpp file + + @param header_filepath: .hpp filepath + @param header: Rendered header + @param num_audios: Audio file index + @param audio_array_namesizes: Audio array name sizes + """ print(f"++ Generating {header_filepath}") - header_template = env.get_template(header_template_file) - hdr = header_template.render(script_name=Path(__file__).name, - gen_time=datetime.datetime.now(), - year=datetime.datetime.now().year) - env.get_template('AudioClips.hpp.template').stream(common_template_header=hdr, - clips_count=num_audios, - varname_size=audio_array_namesizes - ) \ + env \ + .get_template('AudioClips.hpp.template') \ + .stream(common_template_header=header, + clips_count=num_audios, + varname_size=audio_array_namesizes) \ .dump(str(header_filepath)) + +def write_cc_file( + cc_filepath, + header, + num_audios, + audio_filenames, + audio_array_namesizes +): + """ + Write cc file + + @param cc_filepath: .cc filepath + @param header: Rendered header + @param num_audios: Audio file index + @param audio_filenames: Audio filenames + @param audio_array_namesizes: Audio array name sizes + """ print(f"++ Generating {cc_filepath}") - env.get_template('AudioClips.cc.template').stream(common_template_header=hdr, - clips_count=num_audios, - var_names=(name for name, _ in audio_array_namesizes), - clip_sizes=(size for _, size in audio_array_namesizes), - clip_names=audio_filenames) \ + env \ + .get_template('AudioClips.cc.template') \ + .stream(common_template_header=header, + clips_count=num_audios, + var_names=(name for name, _ in audio_array_namesizes), + clip_sizes=(size for _, size in audio_array_namesizes), + clip_names=audio_filenames) \ .dump(str(cc_filepath)) -def write_individual_audio_cc_file(clip_dirpath, clip_filename, - cc_filename, header_template_file, array_name, - sampling_rate_value, mono_value, offset_value, - duration_value, res_type_value, min_len): +def write_individual_audio_cc_file( + resampled_audio: AudioSample, + clip_filename, + cc_filename, + header_template_file, + array_name +): + """ + Writes the provided audio sample to a .cc file + + @param resampled_audio: Audio sample to write + @param clip_filename: File name of the clip + @param cc_filename: File name 
of the .cc file + @param header_template_file: Header template + @param array_name: Name of the array to write + @return: Array length of the audio data written + """ print(f"++ Converting {clip_filename} to {Path(cc_filename).name}") - audio_filepath = Path(clip_dirpath) / clip_filename - clip_data, samplerate = AudioUtils.load_resample_audio_clip(audio_filepath, - sampling_rate_value, mono_value, - offset_value, duration_value, - res_type_value, min_len) # Change from [-1, 1] fp32 range to int16 range. - clip_data = np.clip((clip_data * (1 << 15)), + clip_data = np.clip((resampled_audio.data * (1 << 15)), np.iinfo(np.int16).min, np.iinfo(np.int16).max).flatten().astype(np.int16) - header_template = env.get_template(header_template_file) - hdr = header_template.render(script_name=Path(__file__).name, - gen_time=datetime.datetime.now(), - file_name=clip_filename, - year=datetime.datetime.now().year) + hdr = GenUtils.gen_header(env, header_template_file, clip_filename) hex_line_generator = (', '.join(map(hex, sub_arr)) - for sub_arr in np.array_split(clip_data, math.ceil(len(clip_data)/20))) + for sub_arr in np.array_split(clip_data, math.ceil(len(clip_data) / 20))) - env.get_template('audio.cc.template').stream(common_template_header=hdr, - size=len(clip_data), - var_name=array_name, - audio_data=hex_line_generator) \ + env \ + .get_template('audio.cc.template') \ + .stream(common_template_header=hdr, + size=len(clip_data), + var_name=array_name, + audio_data=hex_line_generator) \ .dump(str(cc_filename)) return len(clip_data) +def create_audio_cc_file(args, filename, array_name, clip_dirpath): + """ + Create an individual audio cpp file + + @param args: User-specified args + @param filename: Audio filename + @param array_name: Name of the array in the audio .cc file + @param clip_dirpath: Audio file directory path + @return: Array length of the audio data written + """ + cc_filename = (Path(args.source_folder_path) / + (Path(filename).stem.replace(" ", "_") + ".cc")) + audio_filepath = Path(clip_dirpath) / filename + audio_sample = GenUtils.read_audio_file(audio_filepath, args.offset, args.duration) + resampled_audio = GenUtils.resample_audio_clip( + audio_sample, args.sampling_rate, args.mono, args.res_type, args.min_samples + ) + return write_individual_audio_cc_file( + resampled_audio, filename, cc_filename, args.license_template, array_name, + ) + + def main(args): + """ + Convert audio files to .cc + .hpp files + @param args: Parsed args + """ # Keep the count of the audio files converted audioclip_idx = 0 audioclip_filenames = [] @@ -131,25 +255,41 @@ def main(args): audioclip_filenames.append(filename) # Save the cc file - cc_filename = Path(args.source_folder_path) / (Path(filename).stem.replace(" ", "_") + ".cc") array_name = "audio" + str(audioclip_idx) - array_size = write_individual_audio_cc_file(clip_dirpath, filename, cc_filename, args.license_template, array_name, - args.sampling_rate, args.mono, args.offset, - args.duration, args.res_type, args.min_samples) + array_size = create_audio_cc_file(args, filename, array_name, clip_dirpath) audioclip_array_names.append((array_name, array_size)) # Increment audio index audioclip_idx = audioclip_idx + 1 - except: + except OSError: if args.verbosity: print(f"Failed to open {filename} as an audio.") if len(audioclip_filenames) > 0: - write_hpp_file(header_filepath, common_cc_filepath, args.license_template, - audioclip_idx, audioclip_filenames, audioclip_array_names) + header = env \ + .get_template(args.license_template) \ + 
.render(script_name=Path(__file__).name, + gen_time=datetime.datetime.now(), + year=datetime.datetime.now().year) + + write_hpp_file( + header_filepath, + header, + audioclip_idx, + audioclip_array_names + ) + + write_cc_file( + common_cc_filepath, + header, + audioclip_idx, + audioclip_filenames, + audioclip_array_names + ) + else: raise FileNotFoundError("No valid audio clip files found.") if __name__ == '__main__': - main(args) + main(parsed_args) diff --git a/scripts/py/gen_default_input_cpp.py b/scripts/py/gen_default_input_cpp.py index 093a606..6056dc1 100644 --- a/scripts/py/gen_default_input_cpp.py +++ b/scripts/py/gen_default_input_cpp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,38 +16,61 @@ """ Utility script to generate the minimum InputFiles.hpp and cpp files required by an application. """ -import datetime -from pathlib import Path from argparse import ArgumentParser +from pathlib import Path from jinja2 import Environment, FileSystemLoader +from gen_utils import GenUtils + parser = ArgumentParser() -parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.") -parser.add_argument("--license_template", type=str, help="Header template file", - default="header_template.txt") -args = parser.parse_args() + +# pylint: disable=duplicate-code +parser.add_argument( + "--header_folder_path", + type=str, + help="path to header folder to be generated." +) + +parser.add_argument( + "--license_template", + type=str, + help="Header template file", + default="header_template.txt" + +) + +parsed_args = parser.parse_args() env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'), trim_blocks=True, lstrip_blocks=True) +# pylint: enable=duplicate-code def write_hpp_file(header_file_path, header_template_file): + """ + Write .hpp file + @param header_file_path: Header file path + @param header_template_file: Header template file + """ print(f"++ Generating {header_file_path}") - header_template = env.get_template(header_template_file) - hdr = header_template.render(script_name=Path(__file__).name, - gen_time=datetime.datetime.now(), - year=datetime.datetime.now().year) - env.get_template('default.hpp.template').stream(common_template_header=hdr) \ + hdr = GenUtils.gen_header(env, header_template_file) + env \ + .get_template('default.hpp.template') \ + .stream(common_template_header=hdr) \ .dump(str(header_file_path)) def main(args): + """ + Generate InputFiles.hpp + .cpp + @param args: Parsed args + """ header_filename = "InputFiles.hpp" header_filepath = Path(args.header_folder_path) / header_filename write_hpp_file(header_filepath, args.license_template) if __name__ == '__main__': - main(args) + main(parsed_args) diff --git a/scripts/py/gen_labels_cpp.py b/scripts/py/gen_labels_cpp.py index 065ed5d..11d5040 100644 --- a/scripts/py/gen_labels_cpp.py +++ b/scripts/py/gen_labels_cpp.py @@ -1,6 +1,6 @@ #!env/bin/python3 -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,47 +21,83 @@ NN model output vector) into a vector list initialiser. 
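# For reference: the code generators in this patch share one Jinja2 pattern, render a
# licence header and then stream a template to disk. A condensed sketch, with
# 'example.cc.template' and the keyword arguments as placeholders:
import datetime
from pathlib import Path
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
                  trim_blocks=True, lstrip_blocks=True)

def render_to_file(template_name: str, output_path: Path, **params):
    """Render a template with a generated common header and write it to output_path."""
    hdr = env.get_template('header_template.txt').render(
        script_name=Path(__file__).name,
        gen_time=datetime.datetime.now(),
        year=datetime.datetime.now().year)
    env.get_template(template_name) \
        .stream(common_template_header=hdr, **params) \
        .dump(str(output_path))

# e.g. render_to_file('example.cc.template', Path('out.cc'), var_name='im0')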
The intention is for this script to be called as part of the build framework to auto-generate the cpp file with labels that can be used in the application without modification. """ -import datetime -from pathlib import Path from argparse import ArgumentParser +from pathlib import Path from jinja2 import Environment, FileSystemLoader +from gen_utils import GenUtils + +# pylint: disable=duplicate-code parser = ArgumentParser() # Label file path -parser.add_argument("--labels_file", type=str, help="Path to the label text file", required=True) +parser.add_argument( + "--labels_file", + type=str, + help="Path to the label text file", + required=True +) + # Output file to be generated -parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.", required=True) -parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.", required=True) -parser.add_argument("--output_file_name", type=str, help="Required output file name", required=True) +parser.add_argument( + "--source_folder_path", + type=str, + help="path to source folder to be generated.", + required=True +) + +parser.add_argument( + "--header_folder_path", + type=str, + help="path to header folder to be generated.", + required=True +) + +parser.add_argument( + "--output_file_name", + type=str, + help="Required output file name", + required=True +) + # Namespaces -parser.add_argument("--namespaces", action='append', default=[]) +parser.add_argument( + "--namespaces", + action='append', + default=[] +) + # License template -parser.add_argument("--license_template", type=str, help="Header template file", - default="header_template.txt") +parser.add_argument( + "--license_template", + type=str, + help="Header template file", + default="header_template.txt" +) -args = parser.parse_args() +parsed_args = parser.parse_args() env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'), trim_blocks=True, lstrip_blocks=True) +# pylint: enable=duplicate-code def main(args): + """ + Generate labels .cpp + @param args: Parsed args + """ # Get the labels from text file - with open(args.labels_file, "r") as f: + with open(args.labels_file, "r", encoding="utf8") as f: labels = f.read().splitlines() # No labels? if len(labels) == 0: - raise Exception(f"no labels found in {args.label_file}") + raise ValueError(f"no labels found in {args.label_file}") - header_template = env.get_template(args.license_template) - hdr = header_template.render(script_name=Path(__file__).name, - gen_time=datetime.datetime.now(), - file_name=Path(args.labels_file).name, - year=datetime.datetime.now().year) + hdr = GenUtils.gen_header(env, args.license_template, Path(args.labels_file).name) hpp_filename = Path(args.header_folder_path) / (args.output_file_name + ".hpp") env.get_template('Labels.hpp.template').stream(common_template_header=hdr, @@ -78,4 +114,4 @@ def main(args): if __name__ == '__main__': - main(args) + main(parsed_args) diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py index e4933b5..933c189 100644 --- a/scripts/py/gen_model_cpp.py +++ b/scripts/py/gen_model_cpp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,29 +18,67 @@ Utility script to generate model c file that can be included in the project directly. 
This should be called as part of cmake framework should the models need to be generated at configuration stage. """ -import datetime +import binascii from argparse import ArgumentParser from pathlib import Path from jinja2 import Environment, FileSystemLoader -import binascii +from gen_utils import GenUtils + +# pylint: disable=duplicate-code parser = ArgumentParser() -parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True) -parser.add_argument("--output_dir", help="Output directory", required=True) -parser.add_argument('-e', '--expression', action='append', default=[], dest="expr") -parser.add_argument('--header', action='append', default=[], dest="headers") -parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces") -parser.add_argument("--license_template", type=str, help="Header template file", - default="header_template.txt") -args = parser.parse_args() +parser.add_argument( + "--tflite_path", + help="Model (.tflite) path", + required=True +) + +parser.add_argument( + "--output_dir", + help="Output directory", + required=True +) + +parser.add_argument( + '-e', + '--expression', + action='append', + default=[], + dest="expr" +) + +parser.add_argument( + '--header', + action='append', + default=[], + dest="headers" +) + +parser.add_argument( + '-ns', + '--namespaces', + action='append', + default=[], + dest="namespaces" +) + +parser.add_argument( + "--license_template", + type=str, + help="Header template file", + default="header_template.txt" +) + +parsed_args = parser.parse_args() env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'), trim_blocks=True, lstrip_blocks=True) +# pylint: enable=duplicate-code def get_tflite_data(tflite_path: str) -> list: """ Reads a binary file and returns a C style array as a @@ -63,15 +101,19 @@ def get_tflite_data(tflite_path: str) -> list: for i in range(0, len(hexstream), 2): if 0 == (i % hex_digits_per_line): hexstring += "\n" - hexstring += '0x' + hexstream[i:i+2] + ", " + hexstring += '0x' + hexstream[i:i + 2] + ", " hexstring += '};\n' return [hexstring] def main(args): + """ + Generate models .cpp + @param args: Parsed args + """ if not Path(args.tflite_path).is_file(): - raise Exception(f"{args.tflite_path} not found") + raise ValueError(f"{args.tflite_path} not found") # Cpp filename: cpp_filename = (Path(args.output_dir) / (Path(args.tflite_path).name + ".cc")).resolve() @@ -80,19 +122,16 @@ def main(args): cpp_filename.parent.mkdir(exist_ok=True) - header_template = env.get_template(args.license_template) - - hdr = header_template.render(script_name=Path(__file__).name, - file_name=Path(args.tflite_path).name, - gen_time=datetime.datetime.now(), - year=datetime.datetime.now().year) + hdr = GenUtils.gen_header(env, args.license_template, Path(args.tflite_path).name) - env.get_template('tflite.cc.template').stream(common_template_header=hdr, - model_data=get_tflite_data(args.tflite_path), - expressions=args.expr, - additional_headers=args.headers, - namespaces=args.namespaces).dump(str(cpp_filename)) + env \ + .get_template('tflite.cc.template') \ + .stream(common_template_header=hdr, + model_data=get_tflite_data(args.tflite_path), + expressions=args.expr, + additional_headers=args.headers, + namespaces=args.namespaces).dump(str(cpp_filename)) if __name__ == '__main__': - main(args) + main(parsed_args) diff --git a/scripts/py/gen_rgb_cpp.py b/scripts/py/gen_rgb_cpp.py index b8d85ee..e1c93bb 100644 --- a/scripts/py/gen_rgb_cpp.py +++ b/scripts/py/gen_rgb_cpp.py @@ 
-1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,87 +18,179 @@ Utility script to convert a set of RGB images in a given location into corresponding cpp files and a single hpp file referencing the vectors from the cpp files. """ -import datetime import glob import math -from pathlib import Path +import typing from argparse import ArgumentParser +from dataclasses import dataclass +from pathlib import Path import numpy as np from PIL import Image, UnidentifiedImageError from jinja2 import Environment, FileSystemLoader +from gen_utils import GenUtils + +# pylint: disable=duplicate-code parser = ArgumentParser() -parser.add_argument("--image_path", type=str, help="path to images folder or image file to convert.") -parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.") -parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.") -parser.add_argument("--image_size", type=int, nargs=2, help="Size (width and height) of the converted images.") -parser.add_argument("--license_template", type=str, help="Header template file", - default="header_template.txt") -args = parser.parse_args() + +parser.add_argument( + "--image_path", + type=str, + help="path to images folder or image file to convert." +) + +parser.add_argument( + "--source_folder_path", + type=str, + help="path to source folder to be generated." +) + +parser.add_argument( + "--header_folder_path", + type=str, + help="path to header folder to be generated." +) + +parser.add_argument( + "--image_size", + type=int, + nargs=2, + help="Size (width and height) of the converted images." 
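# For reference: get_tflite_data() above turns the .tflite flatbuffer into a C-style
# hex array. A standalone sketch of the same idea (the bytes-per-line grouping here
# is illustrative, not necessarily the value used by the script):
import binascii
from pathlib import Path

def tflite_to_hex_lines(tflite_path: str, bytes_per_line: int = 16) -> list:
    """Read a binary model file and return lines of comma-separated hex bytes."""
    hexstream = binascii.hexlify(Path(tflite_path).read_bytes()).decode('utf-8')
    hex_bytes = ['0x' + hexstream[i:i + 2] for i in range(0, len(hexstream), 2)]
    return [', '.join(hex_bytes[i:i + bytes_per_line]) + ','
            for i in range(0, len(hex_bytes), bytes_per_line)]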
+) + +parser.add_argument( + "--license_template", + type=str, + help="Header template file", + default="header_template.txt" +) + +parsed_args = parser.parse_args() env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'), trim_blocks=True, lstrip_blocks=True) -def write_hpp_file(header_file_path, cc_file_path, header_template_file, num_images, image_filenames, - image_array_names, image_size): +# pylint: enable=duplicate-code +@dataclass +class ImagesParams: + """ + Template params for Images.hpp and Images.cc + """ + num_images: int + image_size: typing.Sequence + image_array_names: typing.List[str] + image_filenames: typing.List[str] + + +def write_hpp_file( + images_params: ImagesParams, + header_file_path: Path, + cc_file_path: Path, + header_template_file: str, +): + """ + Write Images.hpp and Images.cc + + @param images_params: Template params + @param header_file_path: Images.hpp path + @param cc_file_path: Images.cc path + @param header_template_file: Header template file name + """ print(f"++ Generating {header_file_path}") - header_template = env.get_template(header_template_file) - hdr = header_template.render(script_name=Path(__file__).name, - gen_time=datetime.datetime.now(), - year=datetime.datetime.now().year) - env.get_template('Images.hpp.template').stream(common_template_header=hdr, - imgs_count=num_images, - img_size=str(image_size[0] * image_size[1] * 3), - var_names=image_array_names) \ + hdr = GenUtils.gen_header(env, header_template_file) + + image_size = str(images_params.image_size[0] * images_params.image_size[1] * 3) + + env \ + .get_template('Images.hpp.template') \ + .stream(common_template_header=hdr, + imgs_count=images_params.num_images, + img_size=image_size, + var_names=images_params.image_array_names) \ .dump(str(header_file_path)) - env.get_template('Images.cc.template').stream(common_template_header=hdr, - var_names=image_array_names, - img_names=image_filenames) \ + env \ + .get_template('Images.cc.template') \ + .stream(common_template_header=hdr, + var_names=images_params.image_array_names, + img_names=images_params.image_filenames) \ .dump(str(cc_file_path)) -def write_individual_img_cc_file(image_filename, cc_filename, header_template_file, original_image, - image_size, array_name): - print(f"++ Converting {image_filename} to {cc_filename.name}") +def resize_crop_image( + original_image: Image.Image, + image_size: typing.Sequence +) -> np.ndarray: + """ + Resize and crop input image - header_template = env.get_template(header_template_file) - hdr = header_template.render(script_name=Path(__file__).name, - gen_time=datetime.datetime.now(), - file_name=image_filename, - year=datetime.datetime.now().year) + @param original_image: Image to resize and crop + @param image_size: New image size + @return: Resized and cropped image + """ # IFM size ifm_width = image_size[0] ifm_height = image_size[1] # Aspect ratio resize - scale_ratio = (float)(max(ifm_width, ifm_height)) / (float)(min(original_image.size[0], original_image.size[1])) - resized_width = (int)(original_image.size[0] * scale_ratio) - resized_height = (int)(original_image.size[1] * scale_ratio) - resized_image = original_image.resize([resized_width,resized_height], Image.Resampling.BILINEAR) + scale_ratio = (float(max(ifm_width, ifm_height)) + / float(min(original_image.size[0], original_image.size[1]))) + resized_width = int(original_image.size[0] * scale_ratio) + resized_height = int(original_image.size[1] * scale_ratio) + resized_image = original_image.resize( + 
size=(resized_width, resized_height), + resample=Image.Resampling.BILINEAR + ) # Crop the center of the image resized_image = resized_image.crop(( - (resized_width - ifm_width) / 2, # left - (resized_height - ifm_height) / 2, # top - (resized_width + ifm_width) / 2, # right + (resized_width - ifm_width) / 2, # left + (resized_height - ifm_height) / 2, # top + (resized_width + ifm_width) / 2, # right (resized_height + ifm_height) / 2 # bottom - )) + )) + + return np.array(resized_image, dtype=np.uint8).flatten() + + +def write_individual_img_cc_file( + rgb_data: np.ndarray, + image_filename: str, + cc_filename: Path, + header_template_file: str, + array_name: str +): + """ + Write image.cc + + @param rgb_data: Image data + @param image_filename: Image file name + @param cc_filename: image.cc path + @param header_template_file: Header template file name + @param array_name: C++ array name + """ + print(f"++ Converting {image_filename} to {cc_filename.name}") + + hdr = GenUtils.gen_header(env, header_template_file, image_filename) - # Convert the image and write it to the cc file - rgb_data = np.array(resized_image, dtype=np.uint8).flatten() hex_line_generator = (', '.join(map(hex, sub_arr)) for sub_arr in np.array_split(rgb_data, math.ceil(len(rgb_data) / 20))) - env.get_template('image.cc.template').stream(common_template_header=hdr, - var_name=array_name, - img_data=hex_line_generator) \ + env \ + .get_template('image.cc.template') \ + .stream(common_template_header=hdr, + var_name=array_name, + img_data=hex_line_generator) \ .dump(str(cc_filename)) def main(args): + """ + Convert images + @param args: Parsed args + """ # Keep the count of the images converted image_idx = 0 image_filenames = [] @@ -123,26 +215,29 @@ def main(args): image_filenames.append(filename) # Save the cc file - cc_filename = Path(args.source_folder_path) / (Path(filename).stem.replace(" ", "_") + ".cc") + cc_filename = (Path(args.source_folder_path) / + (Path(filename).stem.replace(" ", "_") + ".cc")) array_name = "im" + str(image_idx) image_array_names.append(array_name) - write_individual_img_cc_file(filename, cc_filename, args.license_template, - original_image, args.image_size, array_name) + + rgb_data = resize_crop_image(original_image, args.image_size) + write_individual_img_cc_file( + rgb_data, filename, cc_filename, args.license_template, array_name + ) # Increment image index image_idx = image_idx + 1 - header_filename = "InputFiles.hpp" - header_filepath = Path(args.header_folder_path) / header_filename - common_cc_filename = "InputFiles.cc" - common_cc_filepath = Path(args.source_folder_path) / common_cc_filename + header_filepath = Path(args.header_folder_path) / "InputFiles.hpp" + common_cc_filepath = Path(args.source_folder_path) / "InputFiles.cc" + + images_params = ImagesParams(image_idx, args.image_size, image_array_names, image_filenames) if len(image_filenames) > 0: - write_hpp_file(header_filepath, common_cc_filepath, args.license_template, - image_idx, image_filenames, image_array_names, args.image_size) + write_hpp_file(images_params, header_filepath, common_cc_filepath, args.license_template) else: raise FileNotFoundError("No valid images found.") if __name__ == '__main__': - main(args) + main(parsed_args) diff --git a/scripts/py/gen_test_data_cpp.py b/scripts/py/gen_test_data_cpp.py index a9e2b75..1ee55ff 100644 --- a/scripts/py/gen_test_data_cpp.py +++ b/scripts/py/gen_test_data_cpp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 - 2022 Arm Limited and/or its affiliates +# 
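# For reference: the aspect-ratio resize and centre crop performed by
# resize_crop_image() above, as a standalone sketch using only Pillow and NumPy.
import numpy as np
from PIL import Image

def resize_and_centre_crop(image: Image.Image, width: int, height: int) -> np.ndarray:
    """Scale so the target fits, crop the centre, return flattened uint8 RGB data."""
    scale = float(max(width, height)) / float(min(image.size[0], image.size[1]))
    resized = image.resize(
        size=(int(image.size[0] * scale), int(image.size[1] * scale)),
        resample=Image.Resampling.BILINEAR
    )
    left = (resized.size[0] - width) / 2
    top = (resized.size[1] - height) / 2
    cropped = resized.crop((left, top, left + width, top + height))
    return np.array(cropped, dtype=np.uint8).flatten()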
SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,81 +18,170 @@ Utility script to convert a set of pairs of npy files in a given location into corresponding cpp files and a single hpp file referencing the vectors from the cpp files. """ -import datetime import math -import os -import numpy as np +import typing +from argparse import ArgumentParser +from dataclasses import dataclass from pathlib import Path -from argparse import ArgumentParser +import numpy as np from jinja2 import Environment, FileSystemLoader +from gen_utils import GenUtils + +# pylint: disable=duplicate-code parser = ArgumentParser() -parser.add_argument("--data_folder_path", type=str, help="path to ifm-ofm npy folder to convert.") -parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.") -parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.") -parser.add_argument("--usecase", type=str, default="", help="Test data file suffix.") -parser.add_argument("--namespaces", action='append', default=[]) -parser.add_argument("--license_template", type=str, help="Header template file", - default="header_template.txt") -parser.add_argument("-v", "--verbosity", action="store_true") -args = parser.parse_args() +parser.add_argument( + "--data_folder_path", + type=str, + help="path to ifm-ofm npy folder to convert." +) + +parser.add_argument( + "--source_folder_path", + type=str, + help="path to source folder to be generated." +) + +parser.add_argument( + "--header_folder_path", + type=str, + help="path to header folder to be generated." +) + +parser.add_argument( + "--usecase", + type=str, + default="", + help="Test data file suffix." 
+) + +parser.add_argument( + "--namespaces", + action='append', + default=[] +) + +parser.add_argument( + "--license_template", + type=str, + help="Header template file", + default="header_template.txt" +) + +parser.add_argument( + "-v", + "--verbosity", + action="store_true" +) + +parsed_args = parser.parse_args() env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'), trim_blocks=True, lstrip_blocks=True) -def write_hpp_file(header_filename, cc_file_path, header_template_file, num_ifms, num_ofms, - ifm_array_names, ifm_sizes, ofm_array_names, ofm_sizes, iofm_data_type): - header_file_path = Path(args.header_folder_path) / header_filename +# pylint: enable=duplicate-code +@dataclass +class TestDataParams: + """ + Template params for TestData.hpp + TestData.ccc + """ + ifm_count: int + ofm_count: int + ifm_var_names: typing.List[str] + ifm_var_sizes: typing.List[int] + ofm_var_names: typing.List[str] + ofm_var_sizes: typing.List[int] + data_type: str + + +@dataclass +class IofmParams: + """ + Template params for iofmdata.cc + """ + var_name: str + data_type: str + + +def write_hpp_file( + template_params: TestDataParams, + header_filename: str, + cc_file_path: str, + header_template_file: str +): + """ + Write TestData.hpp and TestData.cc + + @param template_params: Template parameters + @param header_filename: TestData.hpp path + @param cc_file_path: TestData.cc path + @param header_template_file: Header template file name + """ + header_file_path = Path(parsed_args.header_folder_path) / header_filename print(f"++ Generating {header_file_path}") - header_template = env.get_template(header_template_file) - hdr = header_template.render(script_name=Path(__file__).name, - gen_time=datetime.datetime.now(), - year=datetime.datetime.now().year) - env.get_template('TestData.hpp.template').stream(common_template_header=hdr, - ifm_count=num_ifms, - ofm_count=num_ofms, - ifm_var_names=ifm_array_names, - ifm_var_sizes=ifm_sizes, - ofm_var_names=ofm_array_names, - ofm_var_sizes=ofm_sizes, - data_type=iofm_data_type, - namespaces=args.namespaces) \ + hdr = GenUtils.gen_header(env, header_template_file) + env \ + .get_template('TestData.hpp.template') \ + .stream(common_template_header=hdr, + ifm_count=template_params.ifm_count, + ofm_count=template_params.ofm_count, + ifm_var_names=template_params.ifm_var_names, + ifm_var_sizes=template_params.ifm_var_sizes, + ofm_var_names=template_params.ofm_var_names, + ofm_var_sizes=template_params.ofm_var_sizes, + data_type=template_params.data_type, + namespaces=parsed_args.namespaces) \ .dump(str(header_file_path)) - env.get_template('TestData.cc.template').stream(common_template_header=hdr, - include_h=header_filename, - ifm_var_names=ifm_array_names, - ofm_var_names=ofm_array_names, - data_type=iofm_data_type, - namespaces=args.namespaces) \ + env \ + .get_template('TestData.cc.template') \ + .stream(common_template_header=hdr, + include_h=header_filename, + ifm_var_names=template_params.ifm_var_names, + ofm_var_names=template_params.ofm_var_names, + data_type=template_params.data_type, + namespaces=parsed_args.namespaces) \ .dump(str(cc_file_path)) -def write_individual_cc_file(filename, cc_filename, header_filename, header_template_file, array_name, iofm_data_type): +def write_individual_cc_file( + template_params: IofmParams, + header_filename: str, + filename: str, + cc_filename: Path, + header_template_file: str +): + """ + Write iofmdata.cc + + @param template_params: Template parameters + @param header_filename: Header file name 
+ @param filename: Input file name + @param cc_filename: iofmdata.cc file name + @param header_template_file: Header template file name + """ print(f"++ Converting {filename} to {cc_filename.name}") - header_template = env.get_template(header_template_file) - hdr = header_template.render(script_name=Path(__file__).name, - gen_time=datetime.datetime.now(), - file_name=filename, - year=datetime.datetime.now().year) + hdr = GenUtils.gen_header(env, header_template_file, filename) # Convert the image and write it to the cc file - fm_data = (np.load(Path(args.data_folder_path) / filename)).flatten() + fm_data = (np.load(Path(parsed_args.data_folder_path) / filename)).flatten() type(fm_data.dtype) hex_line_generator = (', '.join(map(hex, sub_arr)) for sub_arr in np.array_split(fm_data, math.ceil(len(fm_data) / 20))) - env.get_template('iofmdata.cc.template').stream(common_template_header=hdr, - include_h=header_filename, - var_name=array_name, - fm_data=hex_line_generator, - data_type=iofm_data_type, - namespaces=args.namespaces) \ + env \ + .get_template('iofmdata.cc.template') \ + .stream(common_template_header=hdr, + include_h=header_filename, + var_name=template_params.var_name, + fm_data=hex_line_generator, + data_type=template_params.data_type, + namespaces=parsed_args.namespaces) \ .dump(str(cc_filename)) @@ -104,59 +193,117 @@ def get_npy_vec_size(filename: str) -> int: Return: size in bytes """ - data = np.load(Path(args.data_folder_path) / filename) + data = np.load(Path(parsed_args.data_folder_path) / filename) return data.size * data.dtype.itemsize -def main(args): - # Keep the count of the images converted - ifm_array_names = [] - ofm_array_names = [] +def write_cc_files(args, count, iofm_data_type, add_usecase_fname, prefix): + """ + Write all cc files + + @param args: User-provided args + @param count: File count + @param iofm_data_type: Data type + @param add_usecase_fname: Use case suffix + @param prefix: Prefix (ifm/ofm) + @return: Names and sizes of generated C++ arrays + """ + array_names = [] + sizes = [] + + header_filename = get_header_filename(add_usecase_fname) + # In the data_folder_path there should be pairs of ifm-ofm + # It's assumed the ifm-ofm naming convention: ifm0.npy-ofm0.npy, ifm1.npy-ofm1.npy + # count = int(len(list(Path(args.data_folder_path).glob(f'{prefix}*.npy')))) + + for idx in range(count): + # Save the fm cc file + base_name = prefix + str(idx) + filename = base_name + ".npy" + array_name = base_name + add_usecase_fname + cc_filename = Path(args.source_folder_path) / (array_name + ".cc") + array_names.append(array_name) + + template_params = IofmParams( + var_name=array_name, + data_type=iofm_data_type, + ) + + write_individual_cc_file( + template_params, header_filename, filename, cc_filename, args.license_template + ) + sizes.append(get_npy_vec_size(filename)) + + return array_names, sizes + + +def get_header_filename(use_case_filename): + """ + Get the header file name from the use case file name + + @param use_case_filename: The use case file name + @return: The header file name + """ + return "TestData" + use_case_filename + ".hpp" + + +def get_cc_filename(use_case_filename): + """ + Get the cc file name from the use case file name + + @param use_case_filename: The use case file name + @return: The cc file name + """ + return "TestData" + use_case_filename + ".cc" + + +def main(args): + """ + Generate test data + @param args: Parsed args + """ add_usecase_fname = ("_" + args.usecase) if (args.usecase != "") else "" - header_filename = "TestData" 
+ add_usecase_fname + ".hpp" - common_cc_filename = "TestData" + add_usecase_fname + ".cc" + header_filename = get_header_filename(add_usecase_fname) + common_cc_filename = get_cc_filename(add_usecase_fname) # In the data_folder_path there should be pairs of ifm-ofm # It's assumed the ifm-ofm naming convention: ifm0.npy-ofm0.npy, ifm1.npy-ofm1.npy ifms_count = int(len(list(Path(args.data_folder_path).glob('ifm*.npy')))) ofms_count = int(len(list(Path(args.data_folder_path).glob('ofm*.npy')))) - #i_ofms_count = int(len([name for name in os.listdir(os.path.join(args.data_folder_path)) if name.lower().endswith('.npy')]) / 2) - iofm_data_type = "int8_t" if ifms_count > 0: - iofm_data_type = "int8_t" if (np.load(Path(args.data_folder_path) / "ifm0.npy").dtype == np.int8) else "uint8_t" - - ifm_sizes = [] - ofm_sizes = [] + iofm_data_type = "int8_t" \ + if (np.load(str(Path(args.data_folder_path) / "ifm0.npy")).dtype == np.int8) \ + else "uint8_t" - for idx in range(ifms_count): - # Save the fm cc file - base_name = "ifm" + str(idx) - filename = base_name+".npy" - array_name = base_name + add_usecase_fname - cc_filename = Path(args.source_folder_path) / (array_name + ".cc") - ifm_array_names.append(array_name) - write_individual_cc_file(filename, cc_filename, header_filename, args.license_template, array_name, iofm_data_type) - ifm_sizes.append(get_npy_vec_size(filename)) + ifm_array_names, ifm_sizes = write_cc_files( + args, ifms_count, iofm_data_type, add_usecase_fname, prefix="ifm" + ) - for idx in range(ofms_count): - # Save the fm cc file - base_name = "ofm" + str(idx) - filename = base_name+".npy" - array_name = base_name + add_usecase_fname - cc_filename = Path(args.source_folder_path) / (array_name + ".cc") - ofm_array_names.append(array_name) - write_individual_cc_file(filename, cc_filename, header_filename, args.license_template, array_name, iofm_data_type) - ofm_sizes.append(get_npy_vec_size(filename)) + ofm_array_names, ofm_sizes = write_cc_files( + args, ofms_count, iofm_data_type, add_usecase_fname, prefix="ofm" + ) common_cc_filepath = Path(args.source_folder_path) / common_cc_filename - write_hpp_file(header_filename, common_cc_filepath, args.license_template, - ifms_count, ofms_count, ifm_array_names, ifm_sizes, ofm_array_names, ofm_sizes, iofm_data_type) + + template_params = TestDataParams( + ifm_count=ifms_count, + ofm_count=ofms_count, + ifm_var_names=ifm_array_names, + ifm_var_sizes=ifm_sizes, + ofm_var_names=ofm_array_names, + ofm_var_sizes=ofm_sizes, + data_type=iofm_data_type, + ) + + write_hpp_file( + template_params, header_filename, common_cc_filepath, args.license_template + ) if __name__ == '__main__': - if args.verbosity: - print("Running gen_test_data_cpp with args: "+str(args)) - main(args) + if parsed_args.verbosity: + print("Running gen_test_data_cpp with args: " + str(parsed_args)) + main(parsed_args) diff --git a/scripts/py/gen_utils.py b/scripts/py/gen_utils.py index ee33705..6bb4760 100644 --- a/scripts/py/gen_utils.py +++ b/scripts/py/gen_utils.py @@ -1,6 +1,6 @@ #!env/bin/python3 -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,21 +14,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
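# For reference: the dtype detection and byte-size calculation that
# gen_test_data_cpp.py applies to the ifm/ofm .npy files, assuming 'ifm0.npy'
# exists in data_dir.
from pathlib import Path
import numpy as np

def npy_info(data_dir: str):
    """Return the C data type string and size in bytes for ifm0.npy."""
    data = np.load(str(Path(data_dir) / 'ifm0.npy'))
    c_type = 'int8_t' if data.dtype == np.int8 else 'uint8_t'
    return c_type, data.size * data.dtype.itemsize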
- -import soundfile as sf -import resampy +""" +Utility functions for .cc + .hpp file generation +""" +import argparse +import datetime +from dataclasses import dataclass +from pathlib import Path + +import jinja2 import numpy as np +import resampy +import soundfile as sf -class AudioUtils: +@dataclass +class AudioSample: + """ + Represents an audio sample with its sample rate + """ + data: np.ndarray + sample_rate: int + + +class GenUtils: + """ + Class with utility functions for audio and other .cc + .hpp file generation + """ + @staticmethod def res_data_type(res_type_value): """ Returns the input string if is one of the valid resample type """ - import argparse - if res_type_value not in AudioUtils.res_type_list(): - raise argparse.ArgumentTypeError(f"{res_type_value} not valid. Supported only {AudioUtils.res_type_list()}") + if res_type_value not in GenUtils.res_type_list(): + raise argparse.ArgumentTypeError( + f"{res_type_value} not valid. Supported only {GenUtils.res_type_list()}" + ) return res_type_value @staticmethod @@ -39,27 +61,18 @@ class AudioUtils: return ['kaiser_best', 'kaiser_fast'] @staticmethod - def load_resample_audio_clip(path, target_sr=16000, mono=True, offset=0.0, duration=0, res_type='kaiser_best', - min_len=16000): + def read_audio_file( + path, + offset, + duration + ) -> AudioSample: """ - Load and resample an audio clip with the given desired specs. + Reads an audio file to an array - Parameters: - ---------- - path (string): Path to the input audio clip. - target_sr (int, optional): Target sampling rate. Positive number are considered valid, - if zero or negative the native sampling rate of the file will be preserved. Default is 16000. - mono (bool, optional): Specify if the audio file needs to be converted to mono. Default is True. - offset (float, optional): Target sampling rate. Default is 0.0. - duration (int, optional): Target duration. Positive number are considered valid, - if zero or negative the duration of the file will be preserved. Default is 0. - res_type (int, optional): Resample type to use, Default is 'kaiser_best'. - min_len (int, optional): Minimun lenght of the output audio time series. Default is 16000. - - Returns: - ---------- - y (np.ndarray): Output audio time series of shape shape=(n,) or (2, n). 
- sr (int): A scalar number > 0 that represent the sampling rate of `y` + @param path: Path to audio file + @param offset: Offset to read from + @param duration: Duration to read + @return: The audio data and the sample rate """ try: with sf.SoundFile(path) as audio_file: @@ -76,40 +89,115 @@ class AudioUtils: # Load the target number of frames y = audio_file.read(frames=num_frame_duration, dtype=np.float32, always_2d=False).T - - except: + except OSError as err: print(f"Failed to open {path} as an audio.") + raise err + + return AudioSample(y, origin_sr) + + @staticmethod + def _resample_audio( + y, + target_sr, + origin_sr, + res_type + ): + """ + Resamples audio to a different sample rate + + @param y: Audio to resample + @param target_sr: Target sample rate + @param origin_sr: Original sample rate + @param res_type: Resample type + @return: The resampled audio + """ + ratio = float(target_sr) / origin_sr + axis = -1 + n_samples = int(np.ceil(y.shape[axis] * ratio)) + + # Resample using resampy + y_rs = resampy.resample(y, origin_sr, target_sr, filter=res_type, axis=axis) + n_rs_samples = y_rs.shape[axis] + + # Adjust the size + if n_rs_samples > n_samples: + slices = [slice(None)] * y_rs.ndim + slices[axis] = slice(0, n_samples) + y = y_rs[tuple(slices)] + elif n_rs_samples < n_samples: + lengths = [(0, 0)] * y_rs.ndim + lengths[axis] = (0, n_samples - n_rs_samples) + y = np.pad(y_rs, lengths, 'constant', constant_values=0) + + return y + + @staticmethod + def resample_audio_clip( + audio_sample: AudioSample, + target_sr=16000, + mono=True, + res_type='kaiser_best', + min_len=16000 + ) -> AudioSample: + """ + Load and resample an audio clip with the given desired specs. + + Parameters: + ---------- + path (string): Path to the input audio clip. + target_sr (int, optional): Target sampling rate. Positive number are considered valid, + if zero or negative the native sampling rate of the file + will be preserved. Default is 16000. + mono (bool, optional): Specify if the audio file needs to be converted to mono. + Default is True. + offset (float, optional): Target sampling rate. Default is 0.0. + duration (int, optional): Target duration. Positive number are considered valid, + if zero or negative the duration of the file + will be preserved. Default is 0. + res_type (int, optional): Resample type to use, Default is 'kaiser_best'. + min_len (int, optional): Minimum length of the output audio time series. + Default is 16000. + + Returns: + ---------- + y (np.ndarray): Output audio time series of shape=(n,) or (2, n). 
+ sample_rate (int): A scalar number > 0 that represent the sampling rate of `y` + """ + y = audio_sample.data.copy() # Convert to mono if requested and if audio has more than one dimension - if mono and (y.ndim > 1): + if mono and (audio_sample.data.ndim > 1): y = np.mean(y, axis=0) - if not (origin_sr == target_sr) and (target_sr > 0): - ratio = float(target_sr) / origin_sr - axis = -1 - n_samples = int(np.ceil(y.shape[axis] * ratio)) - - # Resample using resampy - y_rs = resampy.resample(y, origin_sr, target_sr, filter=res_type, axis=axis) - n_rs_samples = y_rs.shape[axis] - - # Adjust the size - if n_rs_samples > n_samples: - slices = [slice(None)] * y_rs.ndim - slices[axis] = slice(0, n_samples) - y = y_rs[tuple(slices)] - elif n_rs_samples < n_samples: - lengths = [(0, 0)] * y_rs.ndim - lengths[axis] = (0, n_samples - n_rs_samples) - y = np.pad(y_rs, lengths, 'constant', constant_values=(0)) - - sr = target_sr + if not (audio_sample.sample_rate == target_sr) and (target_sr > 0): + y = GenUtils._resample_audio(y, target_sr, audio_sample.sample_rate, res_type) + sample_rate = target_sr else: - sr = origin_sr + sample_rate = audio_sample.sample_rate # Pad if necessary and min lenght is setted (min_len> 0) if (y.shape[0] < min_len) and (min_len > 0): sample_to_pad = min_len - y.shape[0] - y = np.pad(y, (0, sample_to_pad), 'constant', constant_values=(0)) + y = np.pad(y, (0, sample_to_pad), 'constant', constant_values=0) + + return AudioSample(data=y, sample_rate=sample_rate) - return y, sr + @staticmethod + def gen_header( + env: jinja2.Environment, + header_template_file: str, + file_name: str = None + ) -> str: + """ + Generate common licence header + + :param env: Jinja2 environment + :param header_template_file: Path to the licence header template + :param file_name: Optional generating script file name + :return: Generated licence header as a string + """ + header_template = env.get_template(header_template_file) + return header_template.render(script_name=Path(__file__).name, + gen_time=datetime.datetime.now(), + file_name=file_name, + year=datetime.datetime.now().year) diff --git a/scripts/py/git_pre_push_hooks.sh b/scripts/py/git_pre_push_hooks.sh new file mode 100755 index 0000000..db5706f --- /dev/null +++ b/scripts/py/git_pre_push_hooks.sh @@ -0,0 +1,48 @@ +#!/bin/sh +# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Called by "git push" with no arguments. The hook should +# exit with non-zero status after issuing an appropriate message if +# it wants to stop the push. + +# shellcheck disable=SC2034,SC2162 +while read local_ref local_sha remote_ref remote_sha; do + # We should pass only added or modified C/C++ source files to cppcheck. 
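Looking back at the GenUtils helpers above before the hook script continues: loading and resampling are now two explicit steps that callers chain together. A minimal usage sketch follows, assuming scripts/py is on the import path and that a sample.wav file exists; neither assumption comes from the patch itself.

from gen_utils import GenUtils

# Read the whole clip (offset 0.0 and duration 0 keep the original length),
# then resample to 16 kHz mono with the default 'kaiser_best' filter.
sample = GenUtils.read_audio_file("sample.wav", offset=0.0, duration=0)
resampled = GenUtils.resample_audio_clip(sample, target_sr=16000, mono=True)
print(resampled.sample_rate, resampled.data.shape)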
+ changed_files=$(git diff --name-only HEAD~1 HEAD | grep -iE "\.(c|cpp|cxx|cc|h|hpp|hxx)$" | cut -f 2) + if [ -n "$changed_files" ]; then + # shellcheck disable=SC2086 + clang-format -style=file --dry-run --Werror $changed_files + + exitcode1=$? + if [ $exitcode1 -ne 0 ]; then + echo "Formatting errors found in file: $changed_files. \ + Please run: + \"clang-format -style=file -i $changed_files\" + to correct these errors" + exit $exitcode1 + fi + + # shellcheck disable=SC2086 + cppcheck --enable=performance,portability --error-exitcode=1 --suppress=*:tests* $changed_files + exitcode2=$? + if [ $exitcode2 -ne 0 ]; then + exit $exitcode2 + fi + fi + exit 0 +done + +exit 0 diff --git a/scripts/py/rnnoise_dump_extractor.py b/scripts/py/rnnoise_dump_extractor.py index 715b922..9e6ff1f 100644 --- a/scripts/py/rnnoise_dump_extractor.py +++ b/scripts/py/rnnoise_dump_extractor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,47 +20,84 @@ Example use: python rnnoise_dump_extractor.py --dump_file output.bin --output_dir ./denoised_wavs/ """ -import soundfile as sf -import numpy as np - import argparse -from os import path import struct +import typing +from os import path + +import numpy as np +import soundfile as sf -def extract(fp, output_dir, export_npy): +def extract( + dump_file: typing.IO, + output_dir: str, + export_npy: bool +): + """ + Extract audio file from RNNoise output dump + + @param dump_file: Audio dump file location + @param output_dir: Output direction + @param export_npy: Whether to export the audio as .npy + """ while True: - filename_length = struct.unpack("i", fp.read(4))[0] + filename_length = struct.unpack("i", dump_file.read(4))[0] if filename_length == -1: return - filename = struct.unpack("{}s".format(filename_length), fp.read(filename_length))[0].decode('ascii') - audio_clip_length = struct.unpack("I", fp.read(4))[0] - output_file_name = path.join(output_dir, "denoised_{}".format(filename)) - audio_clip = fp.read(audio_clip_length) - - with sf.SoundFile(output_file_name, 'w', channels=1, samplerate=48000, subtype="PCM_16", endian="LITTLE") as wav_file: + filename = struct \ + .unpack(f"{filename_length}s", dump_file.read(filename_length))[0] \ + .decode('ascii') + + audio_clip_length = struct.unpack("I", dump_file.read(4))[0] + output_file_name = path.join(output_dir, f"denoised_{filename}") + audio_clip = dump_file.read(audio_clip_length) + + with sf.SoundFile(output_file_name, 'w', channels=1, samplerate=48000, subtype="PCM_16", + endian="LITTLE") as wav_file: wav_file.buffer_write(audio_clip, dtype='int16') - print("{} written to disk".format(output_file_name)) + print(f"{output_file_name} written to disk") if export_npy: output_file_name += ".npy" - pack_format = "{}h".format(int(audio_clip_length/2)) + pack_format = f"{int(audio_clip_length / 2)}h" npdata = np.array(struct.unpack(pack_format, audio_clip)).astype(np.int16) np.save(output_file_name, npdata) - print("{} written to disk".format(output_file_name)) + print(f"{output_file_name} written to disk") def main(args): + """ + Run RNNoise audio dump extraction + @param args: Parsed args + """ extract(args.dump_file, args.output_dir, args.export_npy) parser = argparse.ArgumentParser() -parser.add_argument("--dump_file", type=argparse.FileType('rb'), help="Dump file with audio 
files to extract.", required=True) -parser.add_argument("--output_dir", help="Output directory, Warning: Duplicated file names will be overwritten.", required=True) -parser.add_argument("--export_npy", help="Export the audio buffer in NumPy format", action="store_true") -args = parser.parse_args() + +parser.add_argument( + "--dump_file", + type=argparse.FileType('rb'), + help="Dump file with audio files to extract.", + required=True +) + +parser.add_argument( + "--output_dir", + help="Output directory, Warning: Duplicated file names will be overwritten.", + required=True +) + +parser.add_argument( + "--export_npy", + help="Export the audio buffer in NumPy format", + action="store_true" +) + +parsed_args = parser.parse_args() if __name__ == "__main__": - main(args) + main(parsed_args) diff --git a/scripts/py/setup_hooks.py b/scripts/py/setup_hooks.py index ead5e1f..dc3156c 100644 --- a/scripts/py/setup_hooks.py +++ b/scripts/py/setup_hooks.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates -# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,84 +12,56 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import os -import sys +""" +Adds the git hooks script into the appropriate location +""" import argparse +import os +import shutil import subprocess -import stat +import sys +from pathlib import Path + +HOOKS_SCRIPT = "git_pre_push_hooks.sh" + -def set_hooks_dir(hooks_dir): - command = 'git config core.hooksPath {}'.format(hooks_dir) - subprocess.Popen(command.split(), stdout=subprocess.PIPE) +def set_hooks_dir(hooks_dir: str): + """ + Set the hooks path in the git configuration + @param hooks_dir: The hooks directory + """ + command = f'git config core.hooksPath {hooks_dir}' + with subprocess.Popen(command.split(), stdout=subprocess.PIPE) as process: + process.communicate() + return_code = process.returncode -def add_pre_push_hooks(hooks_dir): + if return_code != 0: + raise RuntimeError(f"Could not configure git hooks path, exited with code {return_code}") + + +def add_pre_push_hooks(hooks_dir: str): + """ + Copies the git hooks scripts into the specified location + @param hooks_dir: The specified git hooks directory + """ pre_push = "pre-push" file_path = os.path.join(hooks_dir, pre_push) file_exists = os.path.exists(file_path) if file_exists: os.remove(file_path) - f = open(file_path, "a") - f.write( -'''#!/bin/sh -# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Called by "git push" with no arguments. 
The hook should -# exit with non-zero status after issuing an appropriate message if -# it wants to stop the push. - -while read local_ref local_sha remote_ref remote_sha -do - # We should pass only added or modified C/C++ source files to cppcheck. - changed_files=$(git diff --name-only HEAD~1 HEAD | grep -iE "\.(c|cpp|cxx|cc|h|hpp|hxx)$" | cut -f 2) - if [ -n "$changed_files" ]; then - clang-format -style=file --dry-run --Werror $changed_files - - exitcode1=$? - if [ $exitcode1 -ne 0 ]; then - echo "Formatting errors found in file: $changed_files. - \nPlease run:\n\ \"clang-format -style=file -i $changed_files\" - \nto correct these errors" - exit $exitcode1 - fi - - cppcheck --enable=performance,portability --error-exitcode=1 --suppress=*:tests* $changed_files - exitcode2=$? - if [ $exitcode2 -ne 0 ]; then - exit $exitcode2 - fi - fi - exit 0 -done -exit 0''' -) + script_path = Path(__file__).resolve().parent / HOOKS_SCRIPT + shutil.copy(script_path, hooks_dir) - f.close() - s = os.stat(file_path) - os.chmod(file_path, s.st_mode | stat.S_IEXEC) -parser = argparse.ArgumentParser() -parser.add_argument("git_hooks_path") -args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("git_hooks_path") + args = parser.parse_args() -dir_exists = os.path.exists(args.git_hooks_path) -if not dir_exists: - print('Error! The Git hooks directory you supplied does not exist.') - sys.exit() + if not os.path.exists(args.git_hooks_path): + print('Error! The Git hooks directory you supplied does not exist.') + sys.exit() -add_pre_push_hooks(args.git_hooks_path) -set_hooks_dir(args.git_hooks_path) + add_pre_push_hooks(args.git_hooks_path) + set_hooks_dir(args.git_hooks_path) diff --git a/scripts/py/templates/header_template.txt b/scripts/py/templates/header_template.txt index f6e3bdb..32bf71a 100644 --- a/scripts/py/templates/header_template.txt +++ b/scripts/py/templates/header_template.txt @@ -16,6 +16,6 @@ */ /********************* Autogenerated file. DO NOT EDIT ******************* - * Generated from {{script_name}} tool {% if file_name %}and {{file_name}}{% endif %} file. + * Generated from {{script_name}} tool {% if file_name %}and {{file_name}} {% endif %}file. 
* Date: {{gen_time}} ***************************************************************************/ diff --git a/scripts/py/use_case_resources.json b/scripts/py/use_case_resources.json new file mode 100644 index 0000000..80fa28d --- /dev/null +++ b/scripts/py/use_case_resources.json @@ -0,0 +1,190 @@ +[ + { + "name": "ad", + "url_prefix": [ + "https://github.com/ARM-software/ML-zoo/raw/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/" + ], + "resources": [ + { + "name": "ad_medium_int8.tflite", + "url": "{url_prefix:0}ad_medium_int8.tflite" + }, + {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"}, + {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"} + ] + }, + { + "name": "asr", + "url_prefix": [ + "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/" + ], + "resources": [ + { + "name": "wav2letter_pruned_int8.tflite", + "url": "{url_prefix:0}wav2letter_pruned_int8.tflite" + }, + { + "name": "ifm0.npy", + "url": "{url_prefix:0}testing_input/input_2_int8/0.npy" + }, + { + "name": "ofm0.npy", + "url": "{url_prefix:0}testing_output/Identity_int8/0.npy" + } + ] + }, + { + "name": "img_class", + "url_prefix": [ + "https://github.com/ARM-software/ML-zoo/raw/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8/" + ], + "resources": [ + { + "name": "mobilenet_v2_1.0_224_INT8.tflite", + "url": "{url_prefix:0}mobilenet_v2_1.0_224_INT8.tflite" + }, + { + "name": "ifm0.npy", + "url": "{url_prefix:0}testing_input/tfl.quantize/0.npy" + }, + { + "name": "ofm0.npy", + "url": "{url_prefix:0}testing_output/MobilenetV2/Predictions/Reshape_11/0.npy" + } + ] + }, + { + "name": "object_detection", + "url_prefix": [ + "https://github.com/emza-vs/ModelZoo/blob/v1.0/object_detection/" + ], + "resources": [ + { + "name": "yolo-fastest_192_face_v4.tflite", + "url": "{url_prefix:0}yolo-fastest_192_face_v4.tflite?raw=true" + } + ] + }, + { + "name": "kws", + "url_prefix": [ + "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/" + ], + "resources": [ + {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"}, + {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"}, + { + "name": "kws_micronet_m.tflite", + "url": "{url_prefix:0}kws_micronet_m.tflite" + } + ] + }, + { + "name": "vww", + "url_prefix": [ + "https://github.com/ARM-software/ML-zoo/raw/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8/" + ], + "resources": [ + { + "name": "vww4_128_128_INT8.tflite", + "url": "{url_prefix:0}vww4_128_128_INT8.tflite" + }, + {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"}, + {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"} + ] + }, + { + "name": "kws_asr", + "url_prefix": [ + "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/", + "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/" + ], + "resources": [ + { + "name": "wav2letter_pruned_int8.tflite", + "url": "{url_prefix:0}wav2letter_pruned_int8.tflite" + }, + { + "sub_folder": "asr", + "name": "ifm0.npy", + "url": "{url_prefix:0}testing_input/input_2_int8/0.npy" + }, + { + 
"sub_folder": "asr", + "name": "ofm0.npy", + "url": "{url_prefix:0}testing_output/Identity_int8/0.npy" + }, + { + "sub_folder": "kws", + "name": "ifm0.npy", + "url": "{url_prefix:1}testing_input/input/0.npy" + }, + { + "sub_folder": "kws", + "name": "ofm0.npy", + "url": "{url_prefix:1}testing_output/Identity/0.npy" + }, + { + "name": "kws_micronet_m.tflite", + "url": "{url_prefix:1}kws_micronet_m.tflite" + } + ] + }, + { + "name": "noise_reduction", + "url_prefix": [ + "https://github.com/ARM-software/ML-zoo/raw/a061600058097a2785d6f1f7785e5a2d2a142955/models/noise_suppression/RNNoise/tflite_int8/" + ], + "resources": [ + {"name": "rnnoise_INT8.tflite", "url": "{url_prefix:0}rnnoise_INT8.tflite"}, + { + "name": "ifm0.npy", + "url": "{url_prefix:0}testing_input/main_input_int8/0.npy" + }, + { + "name": "ifm1.npy", + "url": "{url_prefix:0}testing_input/vad_gru_prev_state_int8/0.npy" + }, + { + "name": "ifm2.npy", + "url": "{url_prefix:0}testing_input/noise_gru_prev_state_int8/0.npy" + }, + { + "name": "ifm3.npy", + "url": "{url_prefix:0}testing_input/denoise_gru_prev_state_int8/0.npy" + }, + { + "name": "ofm0.npy", + "url": "{url_prefix:0}testing_output/Identity_int8/0.npy" + }, + { + "name": "ofm1.npy", + "url": "{url_prefix:0}testing_output/Identity_1_int8/0.npy" + }, + { + "name": "ofm2.npy", + "url": "{url_prefix:0}testing_output/Identity_2_int8/0.npy" + }, + { + "name": "ofm3.npy", + "url": "{url_prefix:0}testing_output/Identity_3_int8/0.npy" + }, + { + "name": "ofm4.npy", + "url": "{url_prefix:0}testing_output/Identity_4_int8/0.npy" + } + ] + }, + { + "name": "inference_runner", + "url_prefix": [ + "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/dnn_small/tflite_int8/" + ], + "resources": [ + { + "name": "dnn_s_quantized.tflite", + "url": "{url_prefix:0}dnn_s_quantized.tflite" + } + ] + } +] diff --git a/set_up_default_resources.py b/set_up_default_resources.py index f983508..a3987bc 100755 --- a/set_up_default_resources.py +++ b/set_up_default_resources.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+""" +Script to set up default resources for ML Embedded Evaluation Kit +""" +import dataclasses import errno import fnmatch import json @@ -22,207 +26,21 @@ import re import shutil import subprocess import sys +import typing import urllib.request import venv from argparse import ArgumentParser from argparse import ArgumentTypeError from collections import namedtuple -from urllib.error import URLError +from dataclasses import dataclass from pathlib import Path +from urllib.error import URLError from scripts.py.check_update_resources_downloaded import get_md5sum_for_file - -json_uc_res = [ - { - "use_case_name": "ad", - "url_prefix": [ - "https://github.com/ARM-software/ML-zoo/raw/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/" - ], - "resources": [ - { - "name": "ad_medium_int8.tflite", - "url": "{url_prefix:0}ad_medium_int8.tflite", - }, - {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"}, - {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"}, - ], - }, - { - "use_case_name": "asr", - "url_prefix": [ - "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/" - ], - "resources": [ - { - "name": "wav2letter_pruned_int8.tflite", - "url": "{url_prefix:0}wav2letter_pruned_int8.tflite", - }, - { - "name": "ifm0.npy", - "url": "{url_prefix:0}testing_input/input_2_int8/0.npy", - }, - { - "name": "ofm0.npy", - "url": "{url_prefix:0}testing_output/Identity_int8/0.npy", - }, - ], - }, - { - "use_case_name": "img_class", - "url_prefix": [ - "https://github.com/ARM-software/ML-zoo/raw/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8/" - ], - "resources": [ - { - "name": "mobilenet_v2_1.0_224_INT8.tflite", - "url": "{url_prefix:0}mobilenet_v2_1.0_224_INT8.tflite", - }, - { - "name": "ifm0.npy", - "url": "{url_prefix:0}testing_input/tfl.quantize/0.npy", - }, - { - "name": "ofm0.npy", - "url": "{url_prefix:0}testing_output/MobilenetV2/Predictions/Reshape_11/0.npy", - }, - ], - }, - { - "use_case_name": "object_detection", - "url_prefix": [ - "https://github.com/emza-vs/ModelZoo/blob/v1.0/object_detection/" - ], - "resources": [ - { - "name": "yolo-fastest_192_face_v4.tflite", - "url": "{url_prefix:0}yolo-fastest_192_face_v4.tflite?raw=true", - } - ], - }, - { - "use_case_name": "kws", - "url_prefix": [ - "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/" - ], - "resources": [ - {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"}, - {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"}, - { - "name": "kws_micronet_m.tflite", - "url": "{url_prefix:0}kws_micronet_m.tflite", - }, - ], - }, - { - "use_case_name": "vww", - "url_prefix": [ - "https://github.com/ARM-software/ML-zoo/raw/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8/" - ], - "resources": [ - { - "name": "vww4_128_128_INT8.tflite", - "url": "{url_prefix:0}vww4_128_128_INT8.tflite", - }, - {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"}, - {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"}, - ], - }, - { - "use_case_name": "kws_asr", - "url_prefix": [ - "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/", - 
"https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/", - ], - "resources": [ - { - "name": "wav2letter_pruned_int8.tflite", - "url": "{url_prefix:0}wav2letter_pruned_int8.tflite", - }, - { - "sub_folder": "asr", - "name": "ifm0.npy", - "url": "{url_prefix:0}testing_input/input_2_int8/0.npy", - }, - { - "sub_folder": "asr", - "name": "ofm0.npy", - "url": "{url_prefix:0}testing_output/Identity_int8/0.npy", - }, - { - "sub_folder": "kws", - "name": "ifm0.npy", - "url": "{url_prefix:1}testing_input/input/0.npy", - }, - { - "sub_folder": "kws", - "name": "ofm0.npy", - "url": "{url_prefix:1}testing_output/Identity/0.npy", - }, - { - "name": "kws_micronet_m.tflite", - "url": "{url_prefix:1}kws_micronet_m.tflite", - }, - ], - }, - { - "use_case_name": "noise_reduction", - "url_prefix": [ - "https://github.com/ARM-software/ML-zoo/raw/a061600058097a2785d6f1f7785e5a2d2a142955/models/noise_suppression/RNNoise/tflite_int8/" - ], - "resources": [ - {"name": "rnnoise_INT8.tflite", "url": "{url_prefix:0}rnnoise_INT8.tflite"}, - { - "name": "ifm0.npy", - "url": "{url_prefix:0}testing_input/main_input_int8/0.npy", - }, - { - "name": "ifm1.npy", - "url": "{url_prefix:0}testing_input/vad_gru_prev_state_int8/0.npy", - }, - { - "name": "ifm2.npy", - "url": "{url_prefix:0}testing_input/noise_gru_prev_state_int8/0.npy", - }, - { - "name": "ifm3.npy", - "url": "{url_prefix:0}testing_input/denoise_gru_prev_state_int8/0.npy", - }, - { - "name": "ofm0.npy", - "url": "{url_prefix:0}testing_output/Identity_int8/0.npy", - }, - { - "name": "ofm1.npy", - "url": "{url_prefix:0}testing_output/Identity_1_int8/0.npy", - }, - { - "name": "ofm2.npy", - "url": "{url_prefix:0}testing_output/Identity_2_int8/0.npy", - }, - { - "name": "ofm3.npy", - "url": "{url_prefix:0}testing_output/Identity_3_int8/0.npy", - }, - { - "name": "ofm4.npy", - "url": "{url_prefix:0}testing_output/Identity_4_int8/0.npy", - }, - ], - }, - { - "use_case_name": "inference_runner", - "url_prefix": [ - "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/dnn_small/tflite_int8/" - ], - "resources": [ - { - "name": "dnn_s_quantized.tflite", - "url": "{url_prefix:0}dnn_s_quantized.tflite", - } - ], - }, -] +# Supported version of Python and Vela +VELA_VERSION = "3.9.0" +py3_version_minimum = (3, 9) # Valid NPU configurations: valid_npu_config_names = [ @@ -250,10 +68,57 @@ NPUConfig = namedtuple( ], ) + +@dataclass(frozen=True) +class UseCaseResource: + """ + Represent a use case's resource + """ + name: str + url: str + sub_folder: typing.Optional[str] = None + + +@dataclass(frozen=True) +class UseCase: + """ + Represent a use case + """ + name: str + url_prefix: str + resources: typing.List[UseCaseResource] + + # The internal SRAM size for Corstone-300 implementation on MPS3 specified by AN552 # The internal SRAM size for Corstone-310 implementation on MPS3 specified by AN555 # is 4MB, but we are content with the 2MB specified below. 
-mps3_max_sram_sz = 2 * 1024 * 1024 # 2 MiB (2 banks of 1 MiB each) +MPS3_MAX_SRAM_SZ = 2 * 1024 * 1024 # 2 MiB (2 banks of 1 MiB each) + + +def load_use_case_resources(current_file_dir: Path) -> typing.List[UseCase]: + """ + Load use case metadata resources + + Parameters + ---------- + current_file_dir: Directory of the current script + + Returns + ------- + The use cases resources object parsed to a dict + """ + + resources_path = current_file_dir / "scripts" / "py" / "use_case_resources.json" + with open(resources_path, encoding="utf8") as f: + use_cases = json.load(f) + return [ + UseCase( + name=u["name"], + url_prefix=u["url_prefix"], + resources=[UseCaseResource(**r) for r in u["resources"]], + ) + for u in use_cases + ] def call_command(command: str, verbose: bool = True) -> str: @@ -266,21 +131,22 @@ def call_command(command: str, verbose: bool = True) -> str: """ if verbose: logging.info(command) - proc = subprocess.run( - command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True - ) - log = proc.stdout.decode("utf-8") - if proc.returncode == 0 and verbose: + try: + proc = subprocess.run( + command, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True + ) + log = proc.stdout.decode("utf-8") logging.info(log) - else: + return log + except subprocess.CalledProcessError as err: + log = err.stdout.decode("utf-8") logging.error(log) - proc.check_returncode() - return log + raise err def get_default_npu_config_from_name( - config_name: str, arena_cache_size: int = 0 -) -> NPUConfig: + config_name: str, arena_cache_size: int = 0 +) -> typing.Optional[NPUConfig]: """ Gets the file suffix for the TFLite file from the `accelerator_config` string. @@ -314,15 +180,15 @@ def get_default_npu_config_from_name( system_configs = ["Ethos_U55_High_End_Embedded", "Ethos_U65_High_End"] memory_modes_arena = { # For shared SRAM memory mode, we use the MPS3 SRAM size by default. - "Shared_Sram": mps3_max_sram_sz if arena_cache_size <= 0 else arena_cache_size, + "Shared_Sram": MPS3_MAX_SRAM_SZ if arena_cache_size <= 0 else arena_cache_size, # For dedicated SRAM memory mode, we do not override the arena size. This is expected to # be defined in the Vela configuration file instead. "Dedicated_Sram": None if arena_cache_size <= 0 else arena_cache_size, } - for i in range(len(strings_ids)): - if config_name.startswith(strings_ids[i]): - npu_config_id = config_name.replace(strings_ids[i], prefix_ids[i]) + for i, string_id in enumerate(strings_ids): + if config_name.startswith(string_id): + npu_config_id = config_name.replace(string_id, prefix_ids[i]) return NPUConfig( config_name=config_name, memory_mode=memory_modes[i], @@ -335,97 +201,316 @@ def get_default_npu_config_from_name( return None -def remove_tree_dir(dir_path): +def remove_tree_dir(dir_path: Path): + """ + Delete and re-create a directory + + Parameters + ---------- + dir_path : The directory path + """ try: # Remove the full directory. shutil.rmtree(dir_path) # Re-create an empty one. 
os.mkdir(dir_path) - except Exception as e: - logging.error(f"Failed to delete {dir_path}.") + except OSError: + logging.error("Failed to delete %s.", dir_path) -def set_up_resources( - run_vela_on_models: bool = False, - additional_npu_config_names: tuple = (), - arena_cache_size: int = 0, - check_clean_folder: bool = False, - additional_requirements_file: str = "") -> (Path, Path): +def initialize_use_case_resources_directory( + use_case: UseCase, + metadata: typing.Dict, + download_dir: Path, + check_clean_folder: bool, + setup_script_hash_verified: bool, +): """ - Helpers function that retrieve the output from a command. + Initialize the resources_downloaded directory for a use case - Parameters: - ---------- - run_vela_on_models (bool): Specifies if run vela on downloaded models. - additional_npu_config_names(list): list of strings of Ethos-U NPU configs. - arena_cache_size (int): Specifies arena cache size in bytes. If a value - greater than 0 is provided, this will be taken - as the cache size. If 0, the default values, as per - the NPU config requirements, are used. - check_clean_folder (bool): Indicates whether the resources folder needs to - be checked for updates and cleaned. - additional_requirements_file (str): Path to a requirements.txt file if - additional packages need to be - installed. + @param use_case: The use case + @param metadata: The metadata + @param download_dir: The parent directory + @param check_clean_folder: Whether to clean the folder + @param setup_script_hash_verified: Whether the hash of this script is verified + """ + try: + # Does the usecase_name download dir exist? + (download_dir / use_case.name).mkdir() + except OSError as err: + if err.errno == errno.EEXIST: + # The usecase_name download dir exist. + if check_clean_folder and not setup_script_hash_verified: + for idx, metadata_uc_url_prefix in enumerate( + [ + f + for f in metadata["resources_info"] + if f["name"] == use_case.name + ][0]["url_prefix"] + ): + if metadata_uc_url_prefix != use_case.url_prefix[idx]: + logging.info("Removing %s resources.", use_case.name) + remove_tree_dir(download_dir / use_case.name) + break + elif err.errno != errno.EEXIST: + logging.error("Error creating %s directory.", use_case.name) + raise - Returns - ------- - Tuple of pair of Paths: (download_directory_path, virtual_env_path) +def download_file(url: str, dest: Path): + """ + Download a file - download_directory_path: Root of the directory where the resources have been downloaded to. - virtual_env_path: Path to the root of virtual environment. + @param url: The URL of the file to download + @param dest: The destination of downloaded file """ - # Paths. 
- current_file_dir = Path(__file__).parent.resolve() - download_dir = current_file_dir / "resources_downloaded" - metadata_file_path = download_dir / "resources_downloaded_metadata.json" + try: + with urllib.request.urlopen(url) as g: + with open(dest, "b+w") as f: + f.write(g.read()) + logging.info("- Downloaded %s to %s.", url, dest) + except URLError: + logging.error("URLError while downloading %s.", url) + raise + + +def download_resources( + use_case: UseCase, + metadata: typing.Dict, + download_dir: Path, + check_clean_folder: bool, + setup_script_hash_verified: bool, +): + """ + Download the resources associated with a use case - metadata_dict = dict() - vela_version = "3.9.0" - py3_version_minimum = (3, 9) + @param use_case: The use case + @param metadata: The metadata + @param download_dir: The parent directory + @param check_clean_folder: Whether to clean the folder + @param setup_script_hash_verified: Whether the hash is already verified + """ + initialize_use_case_resources_directory( + use_case, + metadata, + download_dir, + check_clean_folder, + setup_script_hash_verified + ) - # Is Python minimum requirement matched? - py3_version = sys.version_info - if py3_version < py3_version_minimum: - raise Exception( - "ERROR: Python3.9+ is required, please see the documentation on how to update it." + reg_expr_str = r"{url_prefix:(.*\d)}" + reg_expr_pattern = re.compile(reg_expr_str) + for res in use_case.resources: + res_name = res.name + url_prefix_idx = int(reg_expr_pattern.search(res.url).group(1)) + res_url = use_case.url_prefix[url_prefix_idx] + re.sub( + reg_expr_str, "", res.url + ) + + sub_folder = "" + if res.sub_folder is not None: + try: + # Does the usecase_name/sub_folder download dir exist? + (download_dir / use_case.name / res.sub_folder).mkdir() + except OSError as err: + if err.errno != errno.EEXIST: + logging.error( + "Error creating %s/%s directory.", + use_case.name, + res.sub_folder + ) + raise + sub_folder = res.sub_folder + + res_dst = download_dir / use_case.name / sub_folder / res_name + + if res_dst.is_file(): + logging.info("File %s exists, skipping download.", res_dst) + else: + download_file(res_url, res_dst) + + +def run_vela( + config: NPUConfig, + env_activate_cmd: str, + model: Path, + config_file: Path, + output_dir: Path +) -> bool: + """ + Run vela on the specified model + @param config: The NPU configuration + @param env_activate_cmd: The Python venv activation command + @param model: The model + @param config_file: The vela config file + @param output_dir: The output directory + @return: True if the optimisation was skipped, false otherwise + """ + # model name after compiling with vela is an initial model name + _vela suffix + vela_optimised_model_path = model.parent / (model.stem + "_vela.tflite") + + vela_command_arena_cache_size = "" + + if config.arena_cache_size: + vela_command_arena_cache_size = ( + f"--arena-cache-size={config.arena_cache_size}" + ) + + vela_command = ( + f"{env_activate_cmd} && vela {model} " + + f"--accelerator-config={config.config_name} " + + "--optimise Performance " + + f"--config {config_file} " + + f"--memory-mode={config.memory_mode} " + + f"--system-config={config.system_config} " + + f"--output-dir={output_dir} " + + f"{vela_command_arena_cache_size}" + ) + + # We want the name to include the configuration suffix. For example: vela_H128, + # vela_Y512 etc. 
+ new_suffix = "_vela_" + config.ethos_u_config_id + ".tflite" + new_vela_optimised_model_path = model.parent / (model.stem + new_suffix) + + skip_optimisation = new_vela_optimised_model_path.is_file() + + if skip_optimisation: + logging.info( + "File %s exists, skipping optimisation.", + new_vela_optimised_model_path ) else: - logging.info(f"Using Python version: {py3_version}") + call_command(vela_command) + + # Rename default vela model. + vela_optimised_model_path.rename(new_vela_optimised_model_path) + logging.info( + "Renaming %s to %s.", + vela_optimised_model_path, + new_vela_optimised_model_path + ) + return skip_optimisation + + +def run_vela_on_all_models( + current_file_dir: Path, + download_dir: Path, + env_activate_cmd: str, + arena_cache_size: int, + npu_config_names: typing.List[str] +): + """ + Run vela on downloaded models for the specified NPU configurations + + @param current_file_dir: Path to the current directory + @param download_dir: Path to the downloaded resources directory + @param env_activate_cmd: Command used to activate Python venv + @param npu_config_names: Names of NPU configurations for which to run Vela + @param arena_cache_size: The arena cache size + """ + config_file = current_file_dir / "scripts" / "vela" / "default_vela.ini" + models = [ + Path(dirpath) / f + for dirpath, dirnames, files in os.walk(download_dir) + for f in fnmatch.filter(files, "*.tflite") + if "vela" not in f + ] + + # Get npu config tuple for each config name in a list: + npu_configs = [ + get_default_npu_config_from_name(name, arena_cache_size) + for name in npu_config_names + ] + + logging.info("All models will be optimised for these configs:") + for config in npu_configs: + logging.info(config) + + optimisation_skipped = False + + for model in models: + for config in npu_configs: + optimisation_skipped = run_vela( + config, + env_activate_cmd, + model, + config_file, + output_dir=model.parent + ) or optimisation_skipped + + # If any optimisation was skipped, show how to regenerate: + if optimisation_skipped: + logging.warning("One or more optimisations were skipped.") + logging.warning( + "To optimise all the models, please remove the directory %s.", + download_dir + ) + + +def initialize_resources_directory( + download_dir: Path, + check_clean_folder: bool, + metadata_file_path: Path, + setup_script_hash: str +) -> typing.Tuple[typing.Dict, bool]: + """ + Sets up the resources_downloaded directory and checks to see if this script + has been modified since the last time resources were downloaded + + @param download_dir: Path to the resources_downloaded directory + @param check_clean_folder: Determines whether to clean the downloads directory + @param metadata_file_path: Path to the metadata file + @param setup_script_hash: The md5 hash of this script + @return: The metadata and a boolean to indicate whether this + script has changed since it was last run + """ + metadata_dict = {} setup_script_hash_verified = False - setup_script_hash = get_md5sum_for_file(Path(__file__).resolve()) - try: - # 1.1 Does the download dir exist? + if download_dir.is_dir(): + logging.info("'resources_downloaded' directory exists.") + # Check and clean? + if check_clean_folder and metadata_file_path.is_file(): + with open(metadata_file_path, encoding="utf8") as metadata_file: + metadata_dict = json.load(metadata_file) + + vela_in_metadata = metadata_dict["ethosu_vela_version"] + if vela_in_metadata != VELA_VERSION: + # Check if all the resources needs to be removed and regenerated. 
+ # This can happen when the Vela version has changed. + logging.info( + ("Vela version in metadata is %s, current %s." + " Removing the resources and re-download them.", + vela_in_metadata, + VELA_VERSION + ) + ) + remove_tree_dir(download_dir) + metadata_dict = {} + else: + # Check if the set_up_default_resorces.py has changed from last setup + setup_script_hash_verified = ( + metadata_dict.get("set_up_script_md5sum") + == setup_script_hash + ) + else: download_dir.mkdir() - except OSError as e: - if e.errno == errno.EEXIST: - logging.info("'resources_downloaded' directory exists.") - # Check and clean? - if check_clean_folder and metadata_file_path.is_file(): - with open(metadata_file_path) as metadata_file: - metadata_dict = json.load(metadata_file) - vela_in_metadata = metadata_dict["ethosu_vela_version"] - if vela_in_metadata != vela_version: - # Check if all the resources needs to be removed and regenerated. - # This can happen when the Vela version has changed. - logging.info( - f"Vela version in metadata is {vela_in_metadata}, current {vela_version}. Removing the resources and re-download them." - ) - remove_tree_dir(download_dir) - metadata_dict = dict() - else: - # Check if the set_up_default_resorces.py has changed from last setup - setup_script_hash_verified = ( - metadata_dict.get("set_up_script_md5sum") - == setup_script_hash - ) - else: - raise - # 1.2 Does the virtual environment exist? + return metadata_dict, setup_script_hash_verified + + +def set_up_python_venv( + download_dir: Path, + additional_requirements_file: Path = "" +): + """ + Set up the Python environment with which to set up the resources + + @param download_dir: Path to the resources_downloaded directory + @param additional_requirements_file: Optional additional requirements file + @return: Path to the venv Python binary + activate command + """ env_dirname = "env" env_path = download_dir / env_dirname @@ -434,23 +519,28 @@ def set_up_resources( env_python = Path(venv_context.env_exe) - if sys.platform == "win32": - env_activate = f"{venv_context.bin_path}/activate.bat" - else: - env_activate = f". {venv_context.bin_path}/activate" - if not env_python.is_file(): # Create the virtual environment using current interpreter's venv # (not necessarily the system's Python3) venv_builder.create(env_dir=env_path) + if sys.platform == "win32": + env_activate = Path(f"{venv_context.bin_path}/activate.bat") + env_activate_cmd = str(env_activate) + else: + env_activate = Path(f"{venv_context.bin_path}/activate") + env_activate_cmd = f". {env_activate}" + + if not env_activate.is_file(): + venv_builder.install_scripts(venv_context, venv_context.bin_path) + # 1.3 Install additional requirements first, if a valid file has been provided if additional_requirements_file and os.path.isfile(additional_requirements_file): command = f"{env_python} -m pip install -r {additional_requirements_file}" call_command(command) # 1.4 Make sure to have all the main requirements - requirements = [f"ethos-u-vela=={vela_version}"] + requirements = [f"ethos-u-vela=={VELA_VERSION}"] command = f"{env_python} -m pip freeze" packages = call_command(command) for req in requirements: @@ -458,162 +548,130 @@ def set_up_resources( command = f"{env_python} -m pip install {req}" call_command(command) - # 2. Download models - logging.info("Downloading resources.") - for uc in json_uc_res: - use_case_name = uc["use_case_name"] - res_url_prefix = uc["url_prefix"] - try: - # Does the usecase_name download dir exist? 
- (download_dir / use_case_name).mkdir() - except OSError as e: - if e.errno == errno.EEXIST: - # The usecase_name download dir exist. - if check_clean_folder and not setup_script_hash_verified: - for idx, metadata_uc_url_prefix in enumerate( - [ - f - for f in metadata_dict["resources_info"] - if f["use_case_name"] == use_case_name - ][0]["url_prefix"] - ): - if metadata_uc_url_prefix != res_url_prefix[idx]: - logging.info(f"Removing {use_case_name} resources.") - remove_tree_dir(download_dir / use_case_name) - break - elif e.errno != errno.EEXIST: - logging.error(f"Error creating {use_case_name} directory.") - raise - - reg_expr_str = r"{url_prefix:(.*\d)}" - reg_expr_pattern = re.compile(reg_expr_str) - for res in uc["resources"]: - res_name = res["name"] - url_prefix_idx = int(reg_expr_pattern.search(res["url"]).group(1)) - res_url = res_url_prefix[url_prefix_idx] + re.sub( - reg_expr_str, "", res["url"] - ) + return env_path, env_activate_cmd - sub_folder = "" - if "sub_folder" in res: - try: - # Does the usecase_name/sub_folder download dir exist? - (download_dir / use_case_name / res["sub_folder"]).mkdir() - except OSError as e: - if e.errno != errno.EEXIST: - logging.error( - f"Error creating {use_case_name} / {res['sub_folder']} directory." - ) - raise - sub_folder = res["sub_folder"] - - res_dst = download_dir / use_case_name / sub_folder / res_name - - if res_dst.is_file(): - logging.info(f"File {res_dst} exists, skipping download.") - else: - try: - g = urllib.request.urlopen(res_url) - with open(res_dst, "b+w") as f: - f.write(g.read()) - logging.info(f"- Downloaded {res_url} to {res_dst}.") - except URLError: - logging.error(f"URLError while downloading {res_url}.") - raise - # 3. Run vela on models in resources_downloaded - # New models will have same name with '_vela' appended. - # For example: - # original model: kws_micronet_m.tflite - # after vela model: kws_micronet_m_vela_H128.tflite - # - # Note: To avoid to run vela twice on the same model, it's supposed that - # downloaded model names don't contain the 'vela' word. 
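For orientation before the superseded download-and-optimise code below: the metadata written by update_metadata() later in this patch ends up in resources_downloaded_metadata.json with roughly the following shape, shown here as the Python dict it is serialised from. All concrete values are placeholders except the Vela version, and resources_info is simply dataclasses.asdict() applied to each UseCase.

metadata_example = {
    "ethosu_vela_version": "3.9.0",
    "set_up_script_md5sum": "d41d8cd98f00b204e9800998ecf8427e",  # placeholder md5
    "resources_info": [
        {
            "name": "ad",
            "url_prefix": ["https://example.com/models/anomaly_detection/"],  # placeholder
            "resources": [
                {
                    "name": "ad_medium_int8.tflite",
                    "url": "{url_prefix:0}ad_medium_int8.tflite",
                    "sub_folder": None,
                },
            ],
        },
    ],
}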
- if run_vela_on_models is True: - config_file = current_file_dir / "scripts" / "vela" / "default_vela.ini" - models = [ - Path(dirpath) / f - for dirpath, dirnames, files in os.walk(download_dir) - for f in fnmatch.filter(files, "*.tflite") - if "vela" not in f - ] +def update_metadata( + metadata_dict: typing.Dict, + setup_script_hash: str, + json_uc_res: typing.List[UseCase], + metadata_file_path: Path +): + """ + Update the metadata file - # Consolidate all config names while discarding duplicates: - config_names = list(set(default_npu_config_names + additional_npu_config_names)) + @param metadata_dict: The metadata dictionary to update + @param setup_script_hash: The setup script hash + @param json_uc_res: The use case resources metadata + @param metadata_file_path The metadata file path + """ + metadata_dict["ethosu_vela_version"] = VELA_VERSION + metadata_dict["set_up_script_md5sum"] = setup_script_hash.strip("\n") + metadata_dict["resources_info"] = [dataclasses.asdict(uc) for uc in json_uc_res] - # Get npu config tuple for each config name in a list: - npu_configs = [ - get_default_npu_config_from_name(name, arena_cache_size) - for name in config_names - ] + with open(metadata_file_path, "w", encoding="utf8") as metadata_file: + json.dump(metadata_dict, metadata_file, indent=4) - logging.info(f"All models will be optimised for these configs:") - for config in npu_configs: - logging.info(config) - optimisation_skipped = False +def set_up_resources( + run_vela_on_models: bool = False, + additional_npu_config_names: tuple = (), + arena_cache_size: int = 0, + check_clean_folder: bool = False, + additional_requirements_file: Path = "" +) -> Path: + """ + Helpers function that retrieve the output from a command. - for model in models: - output_dir = model.parent - # model name after compiling with vela is an initial model name + _vela suffix - vela_optimised_model_path = model.parent / (model.stem + "_vela.tflite") + Parameters: + ---------- + run_vela_on_models (bool): Specifies if run vela on downloaded models. + additional_npu_config_names(list): list of strings of Ethos-U NPU configs. + arena_cache_size (int): Specifies arena cache size in bytes. If a value + greater than 0 is provided, this will be taken + as the cache size. If 0, the default values, as per + the NPU config requirements, are used. + check_clean_folder (bool): Indicates whether the resources folder needs to + be checked for updates and cleaned. + additional_requirements_file (str): Path to a requirements.txt file if + additional packages need to be + installed. - for config in npu_configs: - vela_command_arena_cache_size = "" + Returns + ------- - if config.arena_cache_size: - vela_command_arena_cache_size = ( - f"--arena-cache-size={config.arena_cache_size}" - ) + Tuple of pair of Paths: (download_directory_path, virtual_env_path) - vela_command = ( - f"{env_activate} && vela {model} " - + f"--accelerator-config={config.config_name} " - + "--optimise Performance " - + f"--config {config_file} " - + f"--memory-mode={config.memory_mode} " - + f"--system-config={config.system_config} " - + f"--output-dir={output_dir} " - + f"{vela_command_arena_cache_size}" - ) + download_directory_path: Root of the directory where the resources have been downloaded to. + virtual_env_path: Path to the root of virtual environment. + """ + # Paths. 
+ current_file_dir = Path(__file__).parent.resolve() + download_dir = current_file_dir / "resources_downloaded" + metadata_file_path = download_dir / "resources_downloaded_metadata.json" - # We want the name to include the configuration suffix. For example: vela_H128, - # vela_Y512 etc. - new_suffix = "_vela_" + config.ethos_u_config_id + ".tflite" - new_vela_optimised_model_path = model.parent / (model.stem + new_suffix) + # Is Python minimum requirement matched? + if sys.version_info < py3_version_minimum: + raise RuntimeError( + f"ERROR: Python{'.'.join(str(i) for i in py3_version_minimum)}+ is required," + f" please see the documentation on how to update it." + ) + logging.info("Using Python version: %s", sys.version_info) - if new_vela_optimised_model_path.is_file(): - logging.info( - f"File {new_vela_optimised_model_path} exists, skipping optimisation." - ) - optimisation_skipped = True - continue + json_uc_res = load_use_case_resources(current_file_dir) + setup_script_hash = get_md5sum_for_file(Path(__file__).resolve()) - call_command(vela_command) + metadata_dict, setup_script_hash_verified = initialize_resources_directory( + download_dir, + check_clean_folder, + metadata_file_path, + setup_script_hash + ) - # Rename default vela model. - vela_optimised_model_path.rename(new_vela_optimised_model_path) - logging.info( - f"Renaming {vela_optimised_model_path} to {new_vela_optimised_model_path}." - ) + env_path, env_activate = set_up_python_venv( + download_dir, + additional_requirements_file + ) - # If any optimisation was skipped, show how to regenerate: - if optimisation_skipped: - logging.warning("One or more optimisations were skipped.") - logging.warning( - f"To optimise all the models, please remove the directory {download_dir}." - ) + # 2. Download models + logging.info("Downloading resources.") + for use_case in json_uc_res: + download_resources( + use_case, + metadata_dict, + download_dir, + check_clean_folder, + setup_script_hash_verified + ) + + # 3. Run vela on models in resources_downloaded + # New models will have same name with '_vela' appended. + # For example: + # original model: kws_micronet_m.tflite + # after vela model: kws_micronet_m_vela_H128.tflite + # + # Note: To avoid to run vela twice on the same model, it's supposed that + # downloaded model names don't contain the 'vela' word. + if run_vela_on_models is True: + # Consolidate all config names while discarding duplicates: + run_vela_on_all_models( + current_file_dir, + download_dir, + env_activate, + arena_cache_size, + npu_config_names=list(set(default_npu_config_names + list(additional_npu_config_names))) + ) # 4. Collect and write metadata logging.info("Collecting and write metadata.") - metadata_dict["ethosu_vela_version"] = vela_version - metadata_dict["set_up_script_md5sum"] = setup_script_hash.strip("\n") - metadata_dict["resources_info"] = json_uc_res - - with open(metadata_file_path, "w") as metadata_file: - json.dump(metadata_dict, metadata_file, indent=4) + update_metadata( + metadata_dict, + setup_script_hash.strip("\n"), + json_uc_res, + metadata_file_path + ) - return download_dir, env_path + return env_path if __name__ == "__main__": -- cgit v1.2.1
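Finally, a usage sketch of the refactored entry point, for orientation only: the import assumes the script is reachable as a module, the argument values are invented, and, as the new return statement shows, the function now returns just the virtual environment path (the docstring still advertises the old (download_dir, env_path) pair).

from set_up_default_resources import set_up_resources

env_path = set_up_resources(
    run_vela_on_models=True,
    additional_npu_config_names=(),   # extra Ethos-U config names could be listed here
    arena_cache_size=0,               # 0 -> use the per-config defaults
    check_clean_folder=True,
    additional_requirements_file="",  # optional extra pip requirements file
)
print(f"Virtual environment created at: {env_path}")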