author     Alex Tawse <alex.tawse@arm.com>   2023-09-29 15:55:38 +0100
committer  Richard <richard.burton@arm.com>  2023-10-26 12:35:48 +0000
commit     daba3cf2e3633cbd0e4f8aabe7578b97e88deee1 (patch)
tree       51024b8025e28ecb2aecd67246e189e25f5a6e6c
parent     a11976fb866f77305708f832e603b963969e6a14 (diff)
download   ml-embedded-evaluation-kit-daba3cf2e3633cbd0e4f8aabe7578b97e88deee1.tar.gz
MLECO-3995: Pylint + Shellcheck compatibility
* All Python scripts updated to abide by Pylint rules
* good-names updated to permit short variable names: i, j, k, f, g, ex
* ignore-long-lines regex updated to allow long lines for licence headers
* Shell scripts now compliant with Shellcheck

Signed-off-by: Alex Tawse <Alex.Tawse@arm.com>
Change-Id: I8d5af8279bc08bb8acfe8f6ee7df34965552bbe5
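For reference, the checks this patch targets can be reproduced locally. The snippet below is an illustrative sketch only (it assumes pylint and shellcheck are installed and on PATH; it is not part of the patch):

    # check_lint.py -- hypothetical helper, not part of this change.
    # Runs Pylint with the repository .pylintrc and Shellcheck on the shell scripts.
    import subprocess
    from pathlib import Path

    repo = Path(__file__).parent.resolve()

    # Lint the Python scripts against the new configuration.
    py_files = [str(p) for p in repo.rglob("*.py")]
    subprocess.run(["pylint", f"--rcfile={repo / '.pylintrc'}", *py_files], check=False)

    # Run Shellcheck over the shell scripts touched by this change.
    sh_files = [str(p) for p in repo.rglob("*.sh")]
    subprocess.run(["shellcheck", *sh_files], check=False)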
-rw-r--r--  .pylintrc  653
-rwxr-xr-x  build_default.py  261
-rwxr-xr-x  download_dependencies.py  96
-rw-r--r--  model_conditioning_examples/post_training_quantization.py  61
-rw-r--r--  model_conditioning_examples/quantization_aware_training.py  68
-rw-r--r--  model_conditioning_examples/setup.sh  9
-rw-r--r--  model_conditioning_examples/training_utils.py  5
-rw-r--r--  model_conditioning_examples/weight_clustering.py  87
-rw-r--r--  model_conditioning_examples/weight_pruning.py  75
-rw-r--r--  scripts/py/check_update_resources_downloaded.py  54
-rw-r--r--  scripts/py/dependency_urls.json  8
-rw-r--r--  scripts/py/gen_audio.py  107
-rw-r--r--  scripts/py/gen_audio_cpp.py  258
-rw-r--r--  scripts/py/gen_default_input_cpp.py  49
-rw-r--r--  scripts/py/gen_labels_cpp.py  74
-rw-r--r--  scripts/py/gen_model_cpp.py  89
-rw-r--r--  scripts/py/gen_rgb_cpp.py  203
-rw-r--r--  scripts/py/gen_test_data_cpp.py  317
-rw-r--r--  scripts/py/gen_utils.py  194
-rwxr-xr-x  scripts/py/git_pre_push_hooks.sh  48
-rw-r--r--  scripts/py/rnnoise_dump_extractor.py  79
-rw-r--r--  scripts/py/setup_hooks.py  109
-rw-r--r--  scripts/py/templates/header_template.txt  2
-rw-r--r--  scripts/py/use_case_resources.json  190
-rwxr-xr-x  set_up_default_resources.py  898
25 files changed, 2914 insertions, 1080 deletions
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000..b6cb7ee
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,653 @@
+# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+[MAIN]
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Clear in-memory caches upon conclusion of linting. Useful if running pylint
+# in a server-like mode.
+clear-cache-post-run=no
+
+# Load and enable all available extensions. Use --list-extensions to see a list
+# all available extensions.
+#enable-all-extensions=
+
+# In error mode, messages with a category besides ERROR or FATAL are
+# suppressed, and no reports are done by default. Error mode is compatible with
+# disabling specific errors.
+#errors-only=
+
+# Always return a 0 (non-error) status code, even if lint errors are found.
+# This is primarily useful in continuous integration scripts.
+#exit-zero=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-allow-list=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
+# for backward compatibility.)
+extension-pkg-whitelist=
+
+# Return non-zero exit code if any of these messages/categories are detected,
+# even if score is above --fail-under value. Syntax same as enable. Messages
+# specified are enabled, while categories only check already-enabled messages.
+fail-on=
+
+# Specify a score threshold under which the program will exit with error.
+fail-under=10.0
+
+# Interpret the stdin as a python script, whose filename needs to be passed as
+# the module_or_package argument.
+#from-stdin=
+
+# Files or directories to be skipped. They should be base names, not paths.
+ignore=CVS
+
+# Add files or directories matching the regular expressions patterns to the
+# ignore-list. The regex matches against paths and can be in Posix or Windows
+# format. Because '\\' represents the directory delimiter on Windows systems,
+# it can't be used as an escape character.
+ignore-paths=
+
+# Files or directories matching the regular expression patterns are skipped.
+# The regex matches against base names, not paths. The default value ignores
+# Emacs file locks
+ignore-patterns=^\.#
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use, and will cap the count on Windows to
+# avoid hangs.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Minimum Python version to use for version dependent checks. Will default to
+# the version used to run pylint.
+py-version=3.9
+
+# Discover python modules and packages in the file system subtree.
+recursive=no
+
+# Add paths to the list of the source roots. Supports globbing patterns. The
+# source root is an absolute path or a path relative to the current working
+# directory used to determine a package namespace for modules located under the
+# source root.
+source-roots=
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+# In verbose mode, extra non-checker-related info will be displayed.
+#verbose=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style. If left empty, argument names will be checked with the set
+# naming style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style. If left empty, attribute names will be checked with the set naming
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+ bar,
+ baz,
+ toto,
+ tutu,
+ tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style. If left empty, class attribute names will be checked
+# with the set naming style.
+#class-attribute-rgx=
+
+# Naming style matching correct class constant names.
+class-const-naming-style=UPPER_CASE
+
+# Regular expression matching correct class constant names. Overrides class-
+# const-naming-style. If left empty, class constant names will be checked with
+# the set naming style.
+#class-const-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style. If left empty, class names will be checked with the set naming style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style. If left empty, constant names will be checked with the set naming
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style. If left empty, function names will be checked with the set
+# naming style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+ j,
+ k,
+ f,
+ g,
+ x,
+ y,
+ ex,
+ Run,
+ _
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style. If left empty, inline iteration names will be checked
+# with the set naming style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style. If left empty, method names will be checked with the set naming style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style. If left empty, module names will be checked with the set naming style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Regular expression matching correct type alias names. If left empty, type
+# alias names will be checked with the set naming style.
+#typealias-rgx=
+
+# Regular expression matching correct type variable names. If left empty, type
+# variable names will be checked with the set naming style.
+#typevar-rgx=
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style. If left empty, variable names will be checked with the set
+# naming style.
+#variable-rgx=
+
+
+[CLASSES]
+
+# Warn about protected attribute access inside special methods
+check-protected-access-in-special-methods=no
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp,
+ asyncSetUp,
+ __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# List of regular expressions of class ancestor names to ignore when counting
+# public methods (see R0903)
+exclude-too-few-public-methods=
+
+# List of qualified class names to ignore when counting class parents (see
+# R0901)
+ignored-parents=
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when caught.
+overgeneral-exceptions=builtins.BaseException,builtins.Exception
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$|^# SPDX-FileCopyrightText.*$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow explicit reexports by alias from a package __init__.
+allow-reexport-from-package=no
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=
+
+# Output a graph (.gv or any supported image format) of external dependencies
+# to the given file (report RP0402 must not be disabled).
+ext-import-graph=
+
+# Output a graph (.gv or any supported image format) of all (i.e. internal and
+# external) dependencies to the given file (report RP0402 must not be
+# disabled).
+import-graph=
+
+# Output a graph (.gv or any supported image format) of internal dependencies
+# to the given file (report RP0402 must not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
+# UNDEFINED.
+confidence=HIGH,
+ CONTROL_FLOW,
+ INFERENCE,
+ INFERENCE_FAILURE,
+ UNDEFINED
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then re-enable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ use-implicit-booleaness-not-comparison-to-string,
+ use-implicit-booleaness-not-comparison-to-zero,
+ use-symbolic-message-instead
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=
+
+
+[METHOD_ARGS]
+
+# List of qualified names (i.e., library.method) which require a timeout
+# parameter e.g. 'requests.api.get,requests.api.post'
+timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+ XXX,
+ TODO
+
+# Regular expression of note tags to take in consideration.
+notes-rgx=
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit,argparse.parse_error
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'fatal', 'error', 'warning', 'refactor',
+# 'convention', and 'info' which contain the number of messages in each
+# category, as well as 'statement' which is the total number of statements
+# analyzed. This score is used by the global evaluation report (RP0004).
+evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+msg-template=
+
+# Set the output format. Available formats are: text, parseable, colorized,
+# json2 (improved json format), json (old json format) and msvs (visual
+# studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+#output-format=
+
+# Tells whether to display a full report or only the messages.
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[SIMILARITIES]
+
+# Comments are removed from the similarity computation
+ignore-comments=yes
+
+# Docstrings are removed from the similarity computation
+ignore-docstrings=yes
+
+# Imports are removed from the similarity computation
+ignore-imports=yes
+
+# Signatures are removed from the similarity computation
+ignore-signatures=yes
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. No available dictionaries : You need to install
+# both the python package and the system dependency for enchant to work.
+spelling-dict=
+
+# List of comma separated words that should be considered directives if they
+# appear at the beginning of a comment and should not be checked.
+spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of symbolic message names to ignore for Mixin members.
+ignored-checks-for-mixins=no-member,
+ not-async-context-manager,
+ not-context-manager,
+ attribute-defined-outside-init
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# Regex pattern to define which classes are considered mixins.
+mixin-class-rgx=.*[Mm]ixin
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of names allowed to shadow builtins
+allowed-redefined-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+ _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
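As an aside, the naming and line-length rules configured above can be illustrated with a small Python fragment (illustrative only, not part of the patch): short names from the good-names list (i, j, k, f, g, ex) are accepted, other identifiers must follow snake_case or UPPER_CASE, and SPDX licence headers are exempt from the 100-character limit via the ignore-long-lines regex.

    # SPDX-FileCopyrightText: Copyright 2023 <an example copyright holder whose licence header exceeds one hundred characters>
    """Example module whose names satisfy the naming rules set in .pylintrc."""

    MAX_RETRIES = 3  # module constants: UPPER_CASE


    def total_item_length(items):  # functions and variables: snake_case
        """Sum item lengths; 'i', 'f' and 'ex' come from the good-names list."""
        total = 0
        for i, item in enumerate(items):
            total += len(item) + i
        try:
            with open("results.txt", "w", encoding="utf8") as f:
                f.write(str(total))
        except OSError as ex:
            print(f"Could not write results: {ex}")
        return total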
diff --git a/build_default.py b/build_default.py
index 1d562f9..907bf4d 100755
--- a/build_default.py
+++ b/build_default.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Script to build the ML Embedded Evaluation kit using default configuration
+"""
import logging
import multiprocessing
import os
@@ -22,6 +25,7 @@ import sys
import threading
from argparse import ArgumentDefaultsHelpFormatter
from argparse import ArgumentParser
+from collections import namedtuple
from pathlib import Path
from set_up_default_resources import default_npu_config_names
@@ -29,66 +33,178 @@ from set_up_default_resources import get_default_npu_config_from_name
from set_up_default_resources import set_up_resources
from set_up_default_resources import valid_npu_config_names
+BuildArgs = namedtuple(
+ "BuildArgs",
+ [
+ "toolchain",
+ "download_resources",
+ "run_vela_on_models",
+ "npu_config_name",
+ "make_jobs",
+ "make_verbose",
+ ],
+)
+
class PipeLogging(threading.Thread):
+ """
+ Class used to log stdout from subprocesses
+ """
+
def __init__(self, log_level):
threading.Thread.__init__(self)
- self.logLevel = log_level
- self.fileRead, self.fileWrite = os.pipe()
- self.pipeIn = os.fdopen(self.fileRead)
+ self.log_level = log_level
+ self.file_read, self.file_write = os.pipe()
+ self.pipe_in = os.fdopen(self.file_read)
self.daemon = False
self.start()
def fileno(self):
- return self.fileWrite
+ """
+ Get self.file_write
+
+ Returns
+ -------
+ self.file_write
+ """
+ return self.file_write
def run(self):
- for line in iter(self.pipeIn.readline, ""):
- logging.log(self.logLevel, line.strip("\n"))
+ """
+ Log the contents of self.pipe_in
+ """
+ for line in iter(self.pipe_in.readline, ""):
+ logging.log(self.log_level, line.strip("\n"))
- self.pipeIn.close()
+ self.pipe_in.close()
def close(self):
- os.close(self.fileWrite)
+ """
+ Close the pipe
+ """
+ os.close(self.file_write)
+
+
+def get_toolchain_file_name(toolchain: str) -> str:
+ """
+ Get the name of the toolchain file for the toolchain.
+
+ Parameters
+ ----------
+ toolchain : name of the specified toolchain
+
+ Returns
+ -------
+ name of the toolchain file corresponding to the specified
+ toolchain
+ """
+ if toolchain == "arm":
+ return "bare-metal-armclang.cmake"
+
+ if toolchain == "gnu":
+ return "bare-metal-gcc.cmake"
+
+ raise ValueError("Toolchain must be one of: gnu, arm")
+
+
+def prep_build_dir(
+ current_file_dir: Path,
+ target_platform: str,
+ target_subsystem: str,
+ npu_config_name: str,
+ toolchain: str
+) -> Path:
+ """
+ Create or clean the build directory for this project.
+
+ Parameters
+ ----------
+ current_file_dir : The current directory of the running script
+ target_platform : The name of the target platform, e.g. "mps3"
+ target_subsystem: : The name of the target subsystem, e.g. "sse-300"
+ npu_config_name : The NPU config name, e.g. "ethos-u55-32"
+ toolchain : The name of the specified toolchain, e.g."arm"
+
+ Returns
+ -------
+ The path to the build directory
+ """
+ build_dir = (
+ current_file_dir /
+ f"cmake-build-{target_platform}-{target_subsystem}-{npu_config_name}-{toolchain}"
+ )
+
+ try:
+ build_dir.mkdir()
+ except FileExistsError:
+ # Directory already exists, clean it.
+ for filepath in build_dir.iterdir():
+ try:
+ if filepath.is_file() or filepath.is_symlink():
+ filepath.unlink()
+ elif filepath.is_dir():
+ shutil.rmtree(filepath)
+ except OSError as err:
+ logging.error("Failed to delete %s. Reason: %s", filepath, err)
+
+ return build_dir
-def run(
- toolchain: str,
- download_resources: bool,
- run_vela_on_models: bool,
- npu_config_name: str,
- make_jobs: int,
- make_verbose: bool,
+def run_command(
+ command: str,
+ logpipe: PipeLogging,
+ fail_message: str
):
"""
+ Run a command and exit upon failure.
+
+ Parameters
+ ----------
+ command : The command to run
+ logpipe : The PipeLogging object to capture stdout
+ fail_message : The message to log upon a non-zero exit code
+ """
+ logging.info("\n\n\n%s\n\n\n", command)
+
+ try:
+ subprocess.run(
+ command, check=True, shell=True, stdout=logpipe, stderr=subprocess.STDOUT
+ )
+ except subprocess.CalledProcessError as err:
+ logging.error(fail_message)
+ logpipe.close()
+ sys.exit(err.returncode)
+
+
+def run(args: BuildArgs):
+ """
Run the helpers scripts.
Parameters:
----------
- toolchain (str) : Specifies if 'gnu' or 'arm' toolchain needs to be used.
- download_resources (bool): Specifies if 'Download resources' step is performed.
- run_vela_on_models (bool): Only if `download_resources` is True, specifies if run vela on downloaded models.
- npu_config_name(str) : Ethos-U NPU configuration name. See "valid_npu_config_names"
+ args (BuildArgs) : Parsed set of build args expecting:
+ - toolchain
+ - download_resources
+ - run_vela_on_models
+    - npu_config_name
+ toolchain (str) : Specifies if 'gnu' or 'arm' toolchain needs to be used.
+ download_resources (bool) : Specifies if 'Download resources' step is performed.
+ run_vela_on_models (bool) : Only if `download_resources` is True, specifies if
+ run vela on downloaded models.
+ npu_config_name(str) : Ethos-U NPU configuration name. See "valid_npu_config_names"
"""
current_file_dir = Path(__file__).parent.resolve()
# 1. Make sure the toolchain is supported, and set the right one here
- supported_toolchain_ids = ["gnu", "arm"]
- assert (
- toolchain in supported_toolchain_ids
- ), f"Toolchain must be from {supported_toolchain_ids}"
- if toolchain == "arm":
- toolchain_file_name = "bare-metal-armclang.cmake"
- elif toolchain == "gnu":
- toolchain_file_name = "bare-metal-gcc.cmake"
+ toolchain_file_name = get_toolchain_file_name(args.toolchain)
# 2. Download models if specified
- if download_resources is True:
+ if args.download_resources is True:
logging.info("Downloading resources.")
- (download_dir, env_path) = set_up_resources(
- run_vela_on_models=run_vela_on_models,
- additional_npu_config_names=[npu_config_name],
+ env_path = set_up_resources(
+ run_vela_on_models=args.run_vela_on_models,
+ additional_npu_config_names=(args.npu_config_name,),
additional_requirements_file=current_file_dir / "scripts" / "py" / "requirements.txt"
)
@@ -96,57 +212,42 @@ def run(
logging.info("Building default configuration.")
target_platform = "mps3"
target_subsystem = "sse-300"
- ethos_u_cfg = get_default_npu_config_from_name(npu_config_name)
- build_dir = current_file_dir / f"cmake-build-{target_platform}-{target_subsystem}-{npu_config_name}-{toolchain}"
- try:
- build_dir.mkdir()
- except FileExistsError:
- # Directory already exists, clean it.
- for filepath in build_dir.iterdir():
- try:
- if filepath.is_file() or filepath.is_symlink():
- filepath.unlink()
- elif filepath.is_dir():
- shutil.rmtree(filepath)
- except Exception as e:
- logging.error(f"Failed to delete {filepath}. Reason: {e}")
+ build_dir = prep_build_dir(
+ current_file_dir,
+ target_platform,
+ target_subsystem,
+ args.npu_config_name,
+ args.toolchain
+ )
logpipe = PipeLogging(logging.INFO)
- cmake_toolchain_file = current_file_dir / "scripts" / "cmake" / "toolchains" / toolchain_file_name
+ cmake_toolchain_file = (
+ current_file_dir /
+ "scripts" /
+ "cmake" /
+ "toolchains" /
+ toolchain_file_name
+ )
+ ethos_u_cfg = get_default_npu_config_from_name(args.npu_config_name)
cmake_path = env_path / "bin" / "cmake"
cmake_command = (
f"{cmake_path} -B {build_dir} -DTARGET_PLATFORM={target_platform}"
- + f" -DTARGET_SUBSYSTEM={target_subsystem}"
- + f" -DCMAKE_TOOLCHAIN_FILE={cmake_toolchain_file}"
- + f" -DETHOS_U_NPU_ID={ethos_u_cfg.ethos_u_npu_id}"
- + f" -DETHOS_U_NPU_CONFIG_ID={ethos_u_cfg.ethos_u_config_id}"
- + f" -DTENSORFLOW_LITE_MICRO_CLEAN_DOWNLOADS=ON"
- )
-
- logging.info(f"\n\n\n{cmake_command}\n\n\n")
- state = subprocess.run(
- cmake_command, shell=True, stdout=logpipe, stderr=subprocess.STDOUT
+ f" -DTARGET_SUBSYSTEM={target_subsystem}"
+ f" -DCMAKE_TOOLCHAIN_FILE={cmake_toolchain_file}"
+ f" -DETHOS_U_NPU_ID={ethos_u_cfg.ethos_u_npu_id}"
+ f" -DETHOS_U_NPU_CONFIG_ID={ethos_u_cfg.ethos_u_config_id}"
+ " -DTENSORFLOW_LITE_MICRO_CLEAN_DOWNLOADS=ON"
)
- if state.returncode != 0:
- logging.error("Failed to configure the project.")
- logpipe.close()
- sys.exit(state.returncode)
+ run_command(cmake_command, logpipe, fail_message="Failed to configure the project.")
- make_command = f"{cmake_path} --build {build_dir} -j{make_jobs}"
- if make_verbose:
+ make_command = f"{cmake_path} --build {build_dir} -j{args.make_jobs}"
+ if args.make_verbose:
make_command += " --verbose"
- logging.info(f"\n\n\n{make_command}\n\n\n")
- state = subprocess.run(
- make_command, shell=True, stdout=logpipe, stderr=subprocess.STDOUT
- )
- if state.returncode != 0:
- logging.error("Failed to build project.")
- logpipe.close()
- sys.exit(state.returncode)
+ run_command(make_command, logpipe, fail_message="Failed to build project.")
logpipe.close()
@@ -185,18 +286,20 @@ if __name__ == "__main__":
parser.add_argument(
"--make-verbose", help="Make runs with VERBOSE=1", action="store_true"
)
- args = parser.parse_args()
+ parsed_args = parser.parse_args()
logging.basicConfig(
filename="log_build_default.log", level=logging.DEBUG, filemode="w"
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
- run(
- args.toolchain.lower(),
- not args.skip_download,
- not args.skip_vela,
- args.npu_config_name,
- args.make_jobs,
- args.make_verbose,
+ build_args = BuildArgs(
+ toolchain=parsed_args.toolchain.lower(),
+ download_resources=not parsed_args.skip_download,
+ run_vela_on_models=not parsed_args.skip_vela,
+ npu_config_name=parsed_args.npu_config_name,
+ make_jobs=parsed_args.make_jobs,
+ make_verbose=parsed_args.make_verbose
)
+
+ run(build_args)
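To make the refactored entry point above concrete: the parsed CLI options are now bundled into the BuildArgs namedtuple before being handed to run(). A hypothetical invocation (values are placeholders for illustration, not defaults taken from the script) would look like:

    from build_default import BuildArgs, run

    # Example values only; the real script fills these from argparse.
    build_args = BuildArgs(
        toolchain="gnu",                 # or "arm" for the armclang toolchain
        download_resources=True,         # perform the 'Download resources' step
        run_vela_on_models=True,         # optimise the downloaded models with Vela
        npu_config_name="ethos-u55-32",  # an Ethos-U NPU configuration name
        make_jobs=4,
        make_verbose=False,
    )
    run(build_args)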
diff --git a/download_dependencies.py b/download_dependencies.py
index 3934f94..786fa4c 100755
--- a/download_dependencies.py
+++ b/download_dependencies.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
-
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,69 +15,92 @@
# limitations under the License.
"""This script does effectively the same as "git submodule update --init" command."""
+import json
import logging
import sys
import tarfile
import tempfile
+import typing
+from pathlib import Path
from urllib.request import urlopen
from zipfile import ZipFile
-from pathlib import Path
-TF = "https://github.com/tensorflow/tflite-micro/archive/568d181ccc1f60e49742fd43b7f97141ee8d45fc.zip"
-CMSIS = "https://github.com/ARM-software/CMSIS_5/archive/a75f01746df18bb5b929dfb8dc6c9407fac3a0f3.zip"
-CMSIS_DSP = "https://github.com/ARM-software/CMSIS-DSP/archive/refs/tags/v1.15.0.zip"
-CMSIS_NN = "https://github.com/ARM-software/CMSIS-NN/archive/refs/85164a811917770d7027a12a57ed3b469dac6537.zip"
-ETHOS_U_CORE_DRIVER = "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/snapshot/ethos-u-core-driver-23.08.tar.gz"
-ETHOS_U_CORE_PLATFORM = "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-platform.git/snapshot/ethos-u-core-platform-23.08.tar.gz"
+def download(
+ url_file: str,
+ to_path: Path,
+):
+ """
+ Download a file from the specified URL
-def download(url_file: str, post_process=None):
+ @param url_file: The URL of the file to download
+ @param to_path: The location to download the file to
+ """
with urlopen(url_file) as response, tempfile.NamedTemporaryFile() as temp:
- logging.info(f"Downloading {url_file} ...")
+ logging.info("Downloading %s ...", url_file)
temp.write(response.read())
temp.seek(0)
- logging.info(f"Finished downloading {url_file}.")
- if post_process:
- post_process(temp)
-
-
-def unzip(file, to_path):
- with ZipFile(file) as z:
- for archive_path in z.infolist():
+ logging.info("Finished downloading %s.", url_file)
+ if url_file.endswith(".tar.gz"):
+ untar(temp, to_path)
+ else:
+ unzip(temp, to_path)
+
+
+def unzip(
+ file: typing.IO[bytes],
+ to_path: Path
+):
+ """
+ Unzip the specified file
+
+ @param file: The file to unzip
+ @param to_path: The location to extract to
+ """
+ with ZipFile(file) as f:
+ for archive_path in f.infolist():
archive_path.filename = archive_path.filename[archive_path.filename.find("/") + 1:]
if archive_path.filename:
- z.extract(archive_path, to_path)
+ f.extract(archive_path, to_path)
target_path = to_path / archive_path.filename
attr = archive_path.external_attr >> 16
if attr != 0:
target_path.chmod(attr)
-def untar(file, to_path):
- with tarfile.open(file) as z:
- for archive_path in z.getmembers():
+def untar(
+ file: bytes,
+ to_path: Path
+):
+ """
+ Untar the specified file
+
+ @param file: The file to untar
+ @param to_path: The location to extract to
+ """
+ with tarfile.open(file) as f:
+ for archive_path in f.getmembers():
index = archive_path.name.find("/")
if index < 0:
continue
archive_path.name = archive_path.name[index + 1:]
if archive_path.name:
- z.extract(archive_path, to_path)
+ f.extract(archive_path, to_path)
def main(dependencies_path: Path):
+ """
+ Download all dependencies
+
+ @param dependencies_path: The path to which the dependencies will be downloaded
+ """
+ dependency_urls_path = (
+ Path(__file__).parent.resolve() / "scripts" / "py" / "dependency_urls.json")
+ with open(dependency_urls_path, encoding="utf8") as f:
+ dependency_urls = json.load(f)
- download(CMSIS,
- lambda file: unzip(file.name, to_path=dependencies_path / "cmsis"))
- download(CMSIS_DSP,
- lambda file: unzip(file.name, to_path=dependencies_path / "cmsis-dsp"))
- download(CMSIS_NN,
- lambda file: unzip(file.name, to_path=dependencies_path / "cmsis-nn"))
- download(ETHOS_U_CORE_DRIVER,
- lambda file: untar(file.name, to_path=dependencies_path / "core-driver"))
- download(ETHOS_U_CORE_PLATFORM,
- lambda file: untar(file.name, to_path=dependencies_path / "core-platform"))
- download(TF,
- lambda file: unzip(file.name, to_path=dependencies_path / "tensorflow"))
+ for name, url in dependency_urls.items():
+ download(url, dependencies_path / name)
if __name__ == '__main__':
@@ -88,6 +110,6 @@ if __name__ == '__main__':
download_dir = Path(__file__).parent.resolve() / "dependencies"
if download_dir.is_dir():
- logging.info(f'{download_dir} exists. Skipping download.')
+ logging.info('%s exists. Skipping download.', download_dir)
else:
main(download_dir)
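The hard-coded URL constants are gone: download targets now live in scripts/py/dependency_urls.json as a mapping of destination directory name to URL, and download() picks untar or unzip from the file extension. A sketch of how the new main() consumes that file (the JSON keys and URLs shown are hypothetical; the real file is not reproduced here):

    import json
    from pathlib import Path

    # Hypothetical shape of scripts/py/dependency_urls.json:
    #   { "cmsis": "https://github.com/ARM-software/CMSIS_5/archive/<commit>.zip",
    #     "core-driver": "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/snapshot/<tag>.tar.gz" }
    dependency_urls_path = Path("scripts") / "py" / "dependency_urls.json"
    with open(dependency_urls_path, encoding="utf8") as f:
        dependency_urls = json.load(f)

    for name, url in dependency_urls.items():
        handler = "untar" if url.endswith(".tar.gz") else "unzip"
        print(f"{name}: would {handler} from {url}")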
diff --git a/model_conditioning_examples/post_training_quantization.py b/model_conditioning_examples/post_training_quantization.py
index a39be0e..42069f5 100644
--- a/model_conditioning_examples/post_training_quantization.py
+++ b/model_conditioning_examples/post_training_quantization.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,28 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with an example of how to perform post-training quantization in TensorFlow.
+This script will provide you with an example of how to perform
+post-training quantization in TensorFlow.
-The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit
-integer values.
+The output from this example will be a TensorFlow Lite model file
+where weights and activations are quantized to 8bit integer values.
-Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
-Ethos NPU.
+Quantization helps reduce the size of your models and is necessary
+for running models on certain hardware such as Arm Ethos NPU.
-In addition to quantizing weights, post-training quantization uses a calibration dataset to
-capture the minimum and maximum values of all variable tensors in your model.
-By capturing these ranges it is possible to fully quantize not just the weights of the model but also the activations.
+In addition to quantizing weights, post-training quantization uses
+a calibration dataset to capture the minimum and maximum values of
+all variable tensors in your model. By capturing these ranges it
+is possible to fully quantize not just the weights of the model
+but also the activations.
-Depending on the model you are quantizing there may be some accuracy loss, but for a lot of models the loss should
-be minimal.
+Depending on the model you are quantizing there may be some accuracy loss,
+but for a lot of models the loss should be minimal.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on post-training quantization
-see: https://www.tensorflow.org/lite/performance/post_training_integer_quant
+For more information on using Vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on post-training quantization see:
+ https://www.tensorflow.org/lite/performance/post_training_integer_quant
"""
+
import pathlib
import numpy as np
@@ -44,7 +50,8 @@ from training_utils import get_data, create_model
def post_training_quantize(keras_model, sample_data):
- """Quantize Keras model using post-training quantization with some sample data.
+ """
+ Quantize Keras model using post-training quantization with some sample data.
TensorFlow Lite will have fp32 inputs/outputs and the model will handle quantizing/dequantizing.
@@ -76,8 +83,14 @@ def post_training_quantize(keras_model, sample_data):
return tflite_model
-def evaluate_tflite_model(tflite_save_path, x_test, y_test):
- """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
+# pylint: disable=duplicate-code
+def evaluate_tflite_model(
+ tflite_save_path: pathlib.Path,
+ x_test: np.ndarray,
+ y_test: np.ndarray
+):
+ """
+ Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
Args:
tflite_save_path: Path to TensorFlow Lite model to test.
@@ -106,6 +119,9 @@ def evaluate_tflite_model(tflite_save_path, x_test, y_test):
def main():
+ """
+ Run post-training quantization
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
@@ -117,7 +133,7 @@ def main():
model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the fp32 model accuracy.
- test_loss, test_acc = model.evaluate(x_test, y_test)
+ test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy float: {test_acc:.3f}")
# Quantize and export the resulting TensorFlow Lite model to file.
@@ -132,7 +148,12 @@ def main():
# Test the quantized model accuracy. Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
+# pylint: enable=duplicate-code
if __name__ == "__main__":
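For context, the post-training quantization flow this script implements can be condensed into the following generic TensorFlow Lite sketch (assumes a Keras model and a NumPy calibration set; it is not a verbatim copy of post_training_quantization.py):

    import numpy as np
    import tensorflow as tf


    def post_training_quantize_sketch(keras_model: tf.keras.Model, sample_data: np.ndarray) -> bytes:
        """Convert a Keras model into a quantized TensorFlow Lite flatbuffer."""
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

        def representative_dataset():
            # Calibration data lets the converter capture min/max activation ranges.
            for sample in sample_data[:100]:
                yield [np.expand_dims(sample, axis=0).astype(np.float32)]

        converter.representative_dataset = representative_dataset
        return converter.convert()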
diff --git a/model_conditioning_examples/quantization_aware_training.py b/model_conditioning_examples/quantization_aware_training.py
index 3d492a7..d590763 100644
--- a/model_conditioning_examples/quantization_aware_training.py
+++ b/model_conditioning_examples/quantization_aware_training.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,31 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with a short example of how to perform quantization aware training in TensorFlow using the
+This script will provide you with a short example of how to perform
+quantization aware training in TensorFlow using the
TensorFlow Model Optimization Toolkit.
-The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit
-integer values.
+The output from this example will be a TensorFlow Lite model file
+where weights and activations are quantized to 8bit integer values.
-Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
-Ethos NPU.
+Quantization helps reduce the size of your models and is necessary
+for running models on certain hardware such as Arm Ethos NPU.
-In quantization aware training (QAT), the error introduced with quantizing from fp32 to int8 is simulated using
-fake quantization nodes. By simulating this quantization error when training, the model can learn better adapted
-weights and minimize accuracy losses caused by the reduced precision.
+In quantization aware training (QAT), the error introduced with
+quantizing from fp32 to int8 is simulated using fake quantization nodes.
+By simulating this quantization error when training,
+the model can learn better adapted weights and minimize accuracy losses
+caused by the reduced precision.
-Minimum and maximum values for activations are also captured during training so activations for every layer can be
-quantized along with the weights later.
+Minimum and maximum values for activations are also captured
+during training so activations for every layer can be quantized
+along with the weights later.
-Quantization is only simulated during training and the training backward passes are still performed in full float
-precision. Actual quantization happens when generating a TensorFlow Lite model.
+Quantization is only simulated during training and the
+training backward passes are still performed in full float precision.
+Actual quantization happens when generating a TensorFlow Lite model.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on quantization aware training
-see: https://www.tensorflow.org/model_optimization/guide/quantization/training
+For more information on using vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on quantization aware training see:
+ https://www.tensorflow.org/model_optimization/guide/quantization/training
"""
import pathlib
@@ -64,13 +71,15 @@ def quantize_and_convert_to_tflite(keras_model):
# After doing quantization aware training all the information for creating a fully quantized
# TensorFlow Lite model is already within the quantization aware Keras model.
- # This means we only need to call convert with default optimizations to generate the quantized TensorFlow Lite model.
+ # This means we only need to call convert with default optimizations to
+ # generate the quantized TensorFlow Lite model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
return tflite_model
+# pylint: disable=duplicate-code
def evaluate_tflite_model(tflite_save_path, x_test, y_test):
"""Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
@@ -101,13 +110,19 @@ def evaluate_tflite_model(tflite_save_path, x_test, y_test):
def main():
+ """
+ Run quantization aware training
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
- # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit we can make our
- # model quantization aware in one line. Once this is done we compile the model and train as normal.
- # It is important to note that the model is only quantization aware and is not quantized yet. The weights are
- # still floating point and will only be converted to int8 when we generate the TensorFlow Lite model later on.
+    # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit
+ # we can make our model quantization aware in one line.
+ # Once this is done we compile the model and train as normal.
+ # It is important to note that the model is only quantization aware
+ # and is not quantized yet.
+ # The weights are still floating point and will only be converted
+ # to int8 when we generate the TensorFlow Lite model later on.
quant_aware_model = tfmot.quantization.keras.quantize_model(model)
quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
@@ -117,7 +132,7 @@ def main():
quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the quantization aware model accuracy.
- test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test)
+ test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy quant aware: {test_acc:.3f}")
# Quantize and save the resulting TensorFlow Lite model to file.
@@ -132,7 +147,12 @@ def main():
# Test quantized model accuracy. Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
+# pylint: enable=duplicate-code
if __name__ == "__main__":
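Condensed, the quantization-aware training flow above boils down to the tfmot calls visible in the diff (sketch only; imports assumed available):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot


    def train_quant_aware_sketch(model: tf.keras.Model, x_train, y_train) -> tf.keras.Model:
        """Wrap a Keras model with fake-quantization nodes and fine-tune it."""
        quant_aware_model = tfmot.quantization.keras.quantize_model(model)
        quant_aware_model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            metrics=["accuracy"],
        )
        # Training still runs in float precision; quantization error is only simulated.
        quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
        return quant_aware_model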
diff --git a/model_conditioning_examples/setup.sh b/model_conditioning_examples/setup.sh
index 92de78a..678f9d3 100644
--- a/model_conditioning_examples/setup.sh
+++ b/model_conditioning_examples/setup.sh
@@ -1,5 +1,7 @@
+#!/bin/bash
+
#----------------------------------------------------------------------------
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,8 +16,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#----------------------------------------------------------------------------
-#!/bin/bash
+
python3 -m venv ./env
+# shellcheck disable=SC1091
source ./env/bin/activate
pip install -U pip
-pip install -r requirements.txt \ No newline at end of file
+pip install -r requirements.txt
diff --git a/model_conditioning_examples/training_utils.py b/model_conditioning_examples/training_utils.py
index a022bd1..2ce94b8 100644
--- a/model_conditioning_examples/training_utils.py
+++ b/model_conditioning_examples/training_utils.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -49,7 +49,8 @@ def create_model():
"""
keras_model = tf.keras.models.Sequential([
- tf.keras.layers.Conv2D(32, 3, padding='same', input_shape=(28, 28, 1), activation=tf.nn.relu),
+ tf.keras.layers.Conv2D(32, 3, padding='same',
+ input_shape=(28, 28, 1), activation=tf.nn.relu),
tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
tf.keras.layers.MaxPool2D(),
tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
diff --git a/model_conditioning_examples/weight_clustering.py b/model_conditioning_examples/weight_clustering.py
index 6672d53..e966336 100644
--- a/model_conditioning_examples/weight_clustering.py
+++ b/model_conditioning_examples/weight_clustering.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,22 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with a short example of how to perform clustering of weights (weight sharing) in
-TensorFlow using the TensorFlow Model Optimization Toolkit.
+This script will provide you with a short example of how to perform
+clustering of weights (weight sharing) in TensorFlow
+using the TensorFlow Model Optimization Toolkit.
-The output from this example will be a TensorFlow Lite model file where weights in each layer have been 'clustered' into
-16 clusters during training - quantization has then been applied on top of this.
+The output from this example will be a TensorFlow Lite model file
+where weights in each layer have been 'clustered' into 16 clusters
+during training - quantization has then been applied on top of this.
-By clustering the model we can improve compression of the model file. This can be essential for deploying certain
-models on systems with limited resources - such as embedded systems using an Arm Ethos NPU.
+By clustering the model we can improve compression of the model file.
+This can be essential for deploying certain models on systems with
+limited resources - such as embedded systems using an Arm Ethos NPU.
-After performing clustering we do post-training quantization to quantize the model and then generate a TensorFlow Lite file.
+After performing clustering we do post-training quantization
+to quantize the model and then generate a TensorFlow Lite file.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on clustering see: https://www.tensorflow.org/model_optimization/guide/clustering
+For more information on using Vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on clustering see:
+ https://www.tensorflow.org/model_optimization/guide/clustering
"""
import pathlib
@@ -42,39 +49,52 @@ from post_training_quantization import post_training_quantize, evaluate_tflite_m
def prepare_for_clustering(keras_model):
"""Prepares a Keras model for clustering."""
- # Choose the number of clusters to use and how to initialize them. Using more clusters will generally
- # reduce accuracy so you will need to find the optimal number for your use-case.
+ # Choose the number of clusters to use and how to initialize them.
+ # Using more clusters will generally reduce accuracy,
+ # so you will need to find the optimal number for your use-case.
number_of_clusters = 16
cluster_centroids_init = tfmot.clustering.keras.CentroidInitialization.LINEAR
- # Apply the clustering wrapper to the whole model so weights in every layer will get clustered. You may find that
- # to avoid too much accuracy loss only certain non-critical layers in your model should be clustered.
- clustering_ready_model = tfmot.clustering.keras.cluster_weights(keras_model,
- number_of_clusters=number_of_clusters,
- cluster_centroids_init=cluster_centroids_init)
+ # Apply the clustering wrapper to the whole model so weights in
+ # every layer will get clustered. You may find that to avoid
+ # too much accuracy loss only certain non-critical layers in
+ # your model should be clustered.
+ clustering_ready_model = tfmot.clustering.keras.cluster_weights(
+ keras_model,
+ number_of_clusters=number_of_clusters,
+ cluster_centroids_init=cluster_centroids_init
+ )
# We must recompile the model after making it ready for clustering.
- clustering_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
- loss=tf.keras.losses.sparse_categorical_crossentropy,
- metrics=['accuracy'])
+ clustering_ready_model.compile(
+ optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+ loss=tf.keras.losses.sparse_categorical_crossentropy,
+ metrics=['accuracy']
+ )
return clustering_ready_model
def main():
+ """
+ Run weight clustering
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
# Compile and train the model first.
- # In general it is easier to do clustering as a fine-tuning step after the model is fully trained.
- model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
- loss=tf.keras.losses.sparse_categorical_crossentropy,
- metrics=['accuracy'])
+ # In general, it is easier to do clustering as a
+ # fine-tuning step after the model is fully trained.
+ model.compile(
+ optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+ loss=tf.keras.losses.sparse_categorical_crossentropy,
+ metrics=['accuracy']
+ )
model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the trained model accuracy.
- test_loss, test_acc = model.evaluate(x_test, y_test)
+ test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy before clustering: {test_acc:.3f}")
# Prepare the model for clustering.
@@ -88,19 +108,26 @@ def main():
# Remove all variables that clustering only needed in the training phase.
model_for_export = tfmot.clustering.keras.strip_clustering(clustered_model)
- # Apply post-training quantization on top of the clustering and save the resulting TensorFlow Lite model to file.
+ # Apply post-training quantization on top of the clustering
+ # and save the resulting TensorFlow Lite model to file.
tflite_model = post_training_quantize(model_for_export, x_train)
tflite_models_dir = pathlib.Path('./conditioned_models/')
tflite_models_dir.mkdir(exist_ok=True, parents=True)
- clustered_quant_model_save_path = tflite_models_dir / 'clustered_post_training_quant_model.tflite'
+ clustered_quant_model_save_path = \
+ tflite_models_dir / 'clustered_post_training_quant_model.tflite'
with open(clustered_quant_model_save_path, 'wb') as f:
f.write(tflite_model)
- # Test the clustered quantized model accuracy. Save time by only testing a subset of the whole data.
+ # Test the clustered quantized model accuracy.
+ # Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(clustered_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ clustered_quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
if __name__ == "__main__":
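A quick way to confirm the weight sharing took effect is to count the distinct values left in each kernel after strip_clustering; clustered layers should report at most the 16 clusters configured above. This is an illustrative check only, not part of the patch, and assumes the model_for_export object produced in main():

import numpy as np

# Illustrative check (not part of the patch): each clustered kernel should
# contain at most number_of_clusters (16) unique values after strip_clustering.
for layer in model_for_export.layers:
    for weights in layer.get_weights():
        if weights.ndim > 1:  # skip bias vectors
            print(f"{layer.name}: {len(np.unique(weights))} unique weight values")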
diff --git a/model_conditioning_examples/weight_pruning.py b/model_conditioning_examples/weight_pruning.py
index cbf9cf9..303b6df 100644
--- a/model_conditioning_examples/weight_pruning.py
+++ b/model_conditioning_examples/weight_pruning.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,23 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with a short example of how to perform magnitude-based weight pruning in TensorFlow
+This script will provide you with a short example of how to perform
+magnitude-based weight pruning in TensorFlow
using the TensorFlow Model Optimization Toolkit.
-The output from this example will be a TensorFlow Lite model file where ~75% percent of the weights have been 'pruned' to the
+The output from this example will be a TensorFlow Lite model file
+where ~75% of the weights have been 'pruned' to the
value 0 during training - quantization has then been applied on top of this.
-By pruning the model we can improve compression of the model file. This can be essential for deploying certain models
-on systems with limited resources - such as embedded systems using Arm Ethos NPU. Also, if the pruned model is run
-on an Arm Ethos NPU then this pruning can improve the execution time of the model.
+By pruning the model we can improve compression of the model file.
+This can be essential for deploying certain models on systems
+with limited resources - such as embedded systems using Arm Ethos NPU.
+Also, if the pruned model is run on an Arm Ethos NPU then
+this pruning can improve the execution time of the model.
-After pruning is complete we do post-training quantization to quantize the model and then generate a TensorFlow Lite file.
+After pruning is complete we do post-training quantization
+to quantize the model and then generate a TensorFlow Lite file.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on weight pruning see: https://www.tensorflow.org/model_optimization/guide/pruning
+For more information on using Vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on weight pruning see:
+ https://www.tensorflow.org/model_optimization/guide/pruning
"""
import pathlib
@@ -43,13 +51,20 @@ from post_training_quantization import post_training_quantize, evaluate_tflite_m
def prepare_for_pruning(keras_model):
"""Prepares a Keras model for pruning."""
- # We use a constant sparsity schedule so the amount of sparsity in the model is kept at the same percent throughout
- # training. An alternative is PolynomialDecay where sparsity can be gradually increased during training.
+ # We use a constant sparsity schedule so the amount of sparsity
+ # in the model is kept at the same percent throughout training.
+ # An alternative is PolynomialDecay where sparsity
+ # can be gradually increased during training.
pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.75, begin_step=0)
- # Apply the pruning wrapper to the whole model so weights in every layer will get pruned. You may find that to avoid
- # too much accuracy loss only certain non-critical layers in your model should be pruned.
- pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(keras_model, pruning_schedule=pruning_schedule)
+ # Apply the pruning wrapper to the whole model
+ # so weights in every layer will get pruned.
+ # You may find that to avoid too much accuracy loss only
+ # certain non-critical layers in your model should be pruned.
+ pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(
+ keras_model,
+ pruning_schedule=pruning_schedule
+ )
# We must recompile the model after making it ready for pruning.
pruning_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
@@ -60,11 +75,15 @@ def prepare_for_pruning(keras_model):
def main():
+ """
+ Run weight pruning
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
# Compile and train the model first.
- # In general it is easier to do pruning as a fine-tuning step after the model is fully trained.
+ # In general, it is easier to do pruning as a fine-tuning step
+ # after the model is fully trained.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
@@ -72,7 +91,7 @@ def main():
model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the trained model accuracy.
- test_loss, test_acc = model.evaluate(x_test, y_test)
+ test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy before pruning: {test_acc:.3f}")
# Prepare the model for pruning and add the pruning update callback needed in training.
@@ -80,14 +99,23 @@ def main():
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
# Continue training the model but now with pruning applied - remember to pass in the callbacks!
- pruned_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True, callbacks=callbacks)
+ pruned_model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=128,
+ epochs=1,
+ verbose=1,
+ shuffle=True,
+ callbacks=callbacks
+ )
test_loss, test_acc = pruned_model.evaluate(x_test, y_test)
print(f"Test accuracy after pruning: {test_acc:.3f}")
# Remove all variables that pruning only needed in the training phase.
model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model)
- # Apply post-training quantization on top of the pruning and save the resulting TensorFlow Lite model to file.
+ # Apply post-training quantization on top of the pruning
+ # and save the resulting TensorFlow Lite model to file.
tflite_model = post_training_quantize(model_for_export, x_train)
tflite_models_dir = pathlib.Path('./conditioned_models/')
@@ -97,9 +125,14 @@ def main():
with open(pruned_quant_model_save_path, 'wb') as f:
f.write(tflite_model)
- # Test the pruned quantized model accuracy. Save time by only testing a subset of the whole data.
+ # Test the pruned quantized model accuracy.
+ # Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(pruned_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ pruned_quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
if __name__ == "__main__":
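Similarly, for the pruning example it can be useful to verify that the achieved sparsity is close to the 0.75 target after strip_pruning. A minimal sketch, again assuming the model_for_export object produced in main():

import numpy as np

# Illustrative check (not part of the patch): the fraction of zero-valued
# weights in each pruned kernel should be close to the 0.75 target sparsity.
for layer in model_for_export.layers:
    for weights in layer.get_weights():
        if weights.ndim > 1:  # skip bias vectors
            print(f"{layer.name}: {np.mean(weights == 0):.2%} zero weights")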
diff --git a/scripts/py/check_update_resources_downloaded.py b/scripts/py/check_update_resources_downloaded.py
index 6e4da21..bdd9d62 100644
--- a/scripts/py/check_update_resources_downloaded.py
+++ b/scripts/py/check_update_resources_downloaded.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+"""
+Contains methods to check if the downloaded resources need to be refreshed
+"""
+import hashlib
import json
import sys
-import hashlib
+import typing
from argparse import ArgumentParser
from pathlib import Path
-def get_md5sum_for_file(filepath: str) -> str:
+def get_md5sum_for_file(
+ filepath: typing.Union[str, Path]
+) -> str:
"""
Function to calculate md5sum for contents of a given file.
@@ -41,7 +46,7 @@ def get_md5sum_for_file(filepath: str) -> str:
def check_update_resources_downloaded(
- resource_downloaded_dir: str, set_up_script_path: str
+ resource_downloaded_dir: str, set_up_script_path: str
):
"""
Function that check if the resources downloaded need to be refreshed.
@@ -55,27 +60,27 @@ def check_update_resources_downloaded(
metadata_file_path = Path(resource_downloaded_dir) / "resources_downloaded_metadata.json"
if metadata_file_path.is_file():
- with open(metadata_file_path) as metadata_json:
-
+ with open(metadata_file_path, encoding="utf8") as metadata_json:
metadata_dict = json.load(metadata_json)
- md5_key = 'set_up_script_md5sum'
- set_up_script_md5sum_metadata = ''
- if md5_key in metadata_dict.keys():
- set_up_script_md5sum_metadata = metadata_dict["set_up_script_md5sum"]
+ md5_key = 'set_up_script_md5sum'
+ set_up_script_md5sum_metadata = ''
+
+ if md5_key in metadata_dict.keys():
+ set_up_script_md5sum_metadata = metadata_dict["set_up_script_md5sum"]
- set_up_script_md5sum_current = get_md5sum_for_file(set_up_script_path)
+ set_up_script_md5sum_current = get_md5sum_for_file(set_up_script_path)
- if set_up_script_md5sum_current == set_up_script_md5sum_metadata:
- return 0
+ if set_up_script_md5sum_current == set_up_script_md5sum_metadata:
+ return 0
- # Return code 1 if the resources need to be refreshed.
- print('Error: hash mismatch!')
- print(f'Metadata: {set_up_script_md5sum_metadata}')
- print(f'Current : {set_up_script_md5sum_current}')
- return 1
+ # Return code 1 if the resources need to be refreshed.
+ print('Error: hash mismatch!')
+ print(f'Metadata: {set_up_script_md5sum_metadata}')
+ print(f'Current : {set_up_script_md5sum_current}')
+ return 1
- # Return error code 2 if the file doesn't exists.
+ # Return error code 2 if the file doesn't exist.
print(f'Error: could not find {metadata_file_path}')
return 2
@@ -99,7 +104,8 @@ if __name__ == "__main__":
raise ValueError(f'Invalid script path: {args.setup_script_path}')
# Check the resources are downloaded as expected
- status = check_update_resources_downloaded(
- args.resource_downloaded_dir,
- args.setup_script_path)
- sys.exit(status)
+ STATUS = check_update_resources_downloaded(
+ args.resource_downloaded_dir,
+ args.setup_script_path
+ )
+ sys.exit(STATUS)
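The script's return codes are 0 (metadata hash matches), 1 (hash mismatch, refresh needed) and 2 (metadata file missing). A minimal sketch of a caller acting on them, assuming scripts/py is on the import path; the directory and script paths below are placeholders, not taken from this patch:

from check_update_resources_downloaded import check_update_resources_downloaded

# Hypothetical caller: trigger a refresh when the downloaded resources are stale.
STATUS = check_update_resources_downloaded(
    "resources_downloaded",        # placeholder resources directory
    "path/to/set_up_script.py"     # placeholder set-up script path
)
if STATUS != 0:
    print(f"Resources need refreshing (code {STATUS}): re-run the set-up script")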
diff --git a/scripts/py/dependency_urls.json b/scripts/py/dependency_urls.json
new file mode 100644
index 0000000..33a84f7
--- /dev/null
+++ b/scripts/py/dependency_urls.json
@@ -0,0 +1,8 @@
+{
+ "cmsis": "https://github.com/ARM-software/CMSIS_5/archive/a75f01746df18bb5b929dfb8dc6c9407fac3a0f3.zip",
+ "cmsis-dsp": "https://github.com/ARM-software/CMSIS-DSP/archive/refs/tags/v1.15.0.zip",
+ "cmsis-nn": "https://github.com/ARM-software/CMSIS-NN/archive/refs/85164a811917770d7027a12a57ed3b469dac6537.zip",
+ "core-driver": "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/snapshot/ethos-u-core-driver-23.08.tar.gz",
+ "core-platform": "https://git.mlplatform.org/ml/ethos-u/ethos-u-core-platform.git/snapshot/ethos-u-core-platform-23.08.tar.gz",
+ "tensorflow": "https://github.com/tensorflow/tflite-micro/archive/568d181ccc1f60e49742fd43b7f97141ee8d45fc.zip"
+}
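The new dependency_urls.json simply maps dependency names to archive URLs. How download_dependencies.py consumes it is not shown in this hunk, so the loop below is only an assumed usage sketch:

import json
from pathlib import Path

# Assumed consumption pattern: load the name -> archive URL mapping and iterate it.
urls = json.loads(Path("scripts/py/dependency_urls.json").read_text(encoding="utf8"))
for name, url in urls.items():
    print(f"{name}: {url}")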
diff --git a/scripts/py/gen_audio.py b/scripts/py/gen_audio.py
index ff33bfb..4d7318c 100644
--- a/scripts/py/gen_audio.py
+++ b/scripts/py/gen_audio.py
@@ -1,6 +1,6 @@
#!env/bin/python3
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,34 +17,99 @@
"""
Utility script to convert an audio clip into eval platform desired spec.
"""
-import soundfile as sf
-
from argparse import ArgumentParser
from os import path
-from gen_utils import AudioUtils
+import soundfile as sf
+
+from gen_utils import GenUtils
parser = ArgumentParser()
-parser.add_argument("--audio_path", help="Audio file path", required=True)
-parser.add_argument("--output_dir", help="Output directory", required=True)
-parser.add_argument("--sampling_rate", type=int, help="target sampling rate.", default=16000)
-parser.add_argument("--mono", type=bool, help="convert signal to mono.", default=True)
-parser.add_argument("--offset", type=float, help="start reading after this time (in seconds).", default=0)
-parser.add_argument("--duration", type=float, help="only load up to this much audio (in seconds).", default=0)
-parser.add_argument("--res_type", type=AudioUtils.res_data_type, help=f"Resample type: {AudioUtils.res_type_list()}.", default='kaiser_best')
-parser.add_argument("--min_samples", type=int, help="Minimum sample number.", default=16000)
-parser.add_argument("-v", "--verbosity", action="store_true")
-args = parser.parse_args()
+
+# pylint: disable=duplicate-code
+parser.add_argument(
+ "--audio_path",
+ help="Audio file path",
+ required=True
+)
+
+parser.add_argument(
+ "--output_dir",
+ help="Output directory",
+ required=True
+)
+
+parser.add_argument(
+ "--sampling_rate",
+ type=int,
+ help="target sampling rate.",
+ default=16000
+)
+
+parser.add_argument(
+ "--mono",
+ type=bool,
+ help="convert signal to mono.",
+ default=True
+)
+
+parser.add_argument(
+ "--offset",
+ type=float,
+ help="start reading after this time (in seconds).",
+ default=0
+)
+
+parser.add_argument(
+ "--duration",
+ type=float,
+ help="only load up to this much audio (in seconds).",
+ default=0
+)
+
+parser.add_argument(
+ "--res_type",
+ type=GenUtils.res_data_type,
+ help=f"Resample type: {GenUtils.res_type_list()}.",
+ default='kaiser_best'
+)
+
+parser.add_argument(
+ "--min_samples",
+ type=int,
+ help="Minimum sample number.",
+ default=16000
+)
+
+parser.add_argument(
+ "-v",
+ "--verbosity",
+ action="store_true"
+)
+# pylint: enable=duplicate-code
+
+parsed_args = parser.parse_args()
def main(args):
- audio_data, samplerate = AudioUtils.load_resample_audio_clip(args.audio_path,
- args.sampling_rate,
- args.mono, args.offset,
- args.duration, args.res_type,
- args.min_samples)
- sf.write(path.join(args.output_dir, path.basename(args.audio_path)), audio_data, samplerate)
+ """
+ Generate the new audio file
+ @param args: Parsed args
+ """
+ audio_sample = GenUtils.read_audio_file(
+ args.audio_path, args.offset, args.duration
+ )
+
+ resampled_audio = GenUtils.resample_audio_clip(
+ audio_sample, args.sampling_rate, args.mono, args.res_type, args.min_samples
+ )
+
+ sf.write(
+ path.join(args.output_dir, path.basename(args.audio_path)),
+ resampled_audio.data,
+ resampled_audio.sample_rate
+ )
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_audio_cpp.py b/scripts/py/gen_audio_cpp.py
index 850a871..89d9ae1 100644
--- a/scripts/py/gen_audio_cpp.py
+++ b/scripts/py/gen_audio_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,93 +21,217 @@ from the cpp files.
import datetime
import glob
import math
-from pathlib import Path
from argparse import ArgumentParser
+from pathlib import Path
import numpy as np
from jinja2 import Environment, FileSystemLoader
-from gen_utils import AudioUtils
+from gen_utils import GenUtils, AudioSample
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--audio_path", type=str, help="path to audio folder to convert.")
-parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.")
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.")
-parser.add_argument("--sampling_rate", type=int, help="target sampling rate.", default=16000)
-parser.add_argument("--mono", type=bool, help="convert signal to mono.", default=True)
-parser.add_argument("--offset", type=float, help="start reading after this time (in seconds).", default=0)
-parser.add_argument("--duration", type=float, help="only load up to this much audio (in seconds).", default=0)
-parser.add_argument("--res_type", type=AudioUtils.res_data_type, help=f"Resample type: {AudioUtils.res_type_list()}.",
- default='kaiser_best')
-parser.add_argument("--min_samples", type=int, help="Minimum sample number.", default=16000)
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-parser.add_argument("-v", "--verbosity", action="store_true")
-args = parser.parse_args()
+
+parser.add_argument(
+ "--audio_path",
+ type=str,
+ help="path to audio folder to convert."
+)
+
+parser.add_argument(
+ "--source_folder_path",
+ type=str,
+ help="path to source folder to be generated."
+)
+
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated."
+)
+
+parser.add_argument(
+ "--sampling_rate",
+ type=int,
+ help="target sampling rate.",
+ default=16000
+)
+
+parser.add_argument(
+ "--mono",
+ type=bool,
+ help="convert signal to mono.",
+ default=True
+)
+
+parser.add_argument(
+ "--offset",
+ type=float,
+ help="start reading after this time (in seconds).",
+ default=0
+)
+
+parser.add_argument(
+ "--duration",
+ type=float,
+ help="only load up to this much audio (in seconds).",
+ default=0
+)
+
+parser.add_argument(
+ "--res_type",
+ type=GenUtils.res_data_type,
+ help=f"Resample type: {GenUtils.res_type_list()}.",
+ default='kaiser_best'
+)
+
+parser.add_argument(
+ "--min_samples",
+ type=int,
+ help="Minimum sample number.",
+ default=16000
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parser.add_argument(
+ "-v",
+ "--verbosity",
+ action="store_true"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
-def write_hpp_file(header_filepath, cc_filepath, header_template_file, num_audios, audio_filenames, audio_array_namesizes):
+# pylint: enable=duplicate-code
+def write_hpp_file(
+ header_filepath,
+ header,
+ num_audios,
+ audio_array_namesizes
+):
+ """
+ Write audio hpp file
+
+ @param header_filepath: .hpp filepath
+ @param header: Rendered header
+    @param num_audios:          Number of audio clips converted
+ @param audio_array_namesizes: Audio array name sizes
+ """
print(f"++ Generating {header_filepath}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
- env.get_template('AudioClips.hpp.template').stream(common_template_header=hdr,
- clips_count=num_audios,
- varname_size=audio_array_namesizes
- ) \
+ env \
+ .get_template('AudioClips.hpp.template') \
+ .stream(common_template_header=header,
+ clips_count=num_audios,
+ varname_size=audio_array_namesizes) \
.dump(str(header_filepath))
+
+def write_cc_file(
+ cc_filepath,
+ header,
+ num_audios,
+ audio_filenames,
+ audio_array_namesizes
+):
+ """
+ Write cc file
+
+ @param cc_filepath: .cc filepath
+ @param header: Rendered header
+    @param num_audios:          Number of audio clips converted
+ @param audio_filenames: Audio filenames
+ @param audio_array_namesizes: Audio array name sizes
+ """
print(f"++ Generating {cc_filepath}")
- env.get_template('AudioClips.cc.template').stream(common_template_header=hdr,
- clips_count=num_audios,
- var_names=(name for name, _ in audio_array_namesizes),
- clip_sizes=(size for _, size in audio_array_namesizes),
- clip_names=audio_filenames) \
+ env \
+ .get_template('AudioClips.cc.template') \
+ .stream(common_template_header=header,
+ clips_count=num_audios,
+ var_names=(name for name, _ in audio_array_namesizes),
+ clip_sizes=(size for _, size in audio_array_namesizes),
+ clip_names=audio_filenames) \
.dump(str(cc_filepath))
-def write_individual_audio_cc_file(clip_dirpath, clip_filename,
- cc_filename, header_template_file, array_name,
- sampling_rate_value, mono_value, offset_value,
- duration_value, res_type_value, min_len):
+def write_individual_audio_cc_file(
+ resampled_audio: AudioSample,
+ clip_filename,
+ cc_filename,
+ header_template_file,
+ array_name
+):
+ """
+ Writes the provided audio sample to a .cc file
+
+ @param resampled_audio: Audio sample to write
+ @param clip_filename: File name of the clip
+ @param cc_filename: File name of the .cc file
+ @param header_template_file: Header template
+ @param array_name: Name of the array to write
+ @return: Array length of the audio data written
+ """
print(f"++ Converting {clip_filename} to {Path(cc_filename).name}")
- audio_filepath = Path(clip_dirpath) / clip_filename
- clip_data, samplerate = AudioUtils.load_resample_audio_clip(audio_filepath,
- sampling_rate_value, mono_value,
- offset_value, duration_value,
- res_type_value, min_len)
# Change from [-1, 1] fp32 range to int16 range.
- clip_data = np.clip((clip_data * (1 << 15)),
+ clip_data = np.clip((resampled_audio.data * (1 << 15)),
np.iinfo(np.int16).min,
np.iinfo(np.int16).max).flatten().astype(np.int16)
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- file_name=clip_filename,
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, header_template_file, clip_filename)
hex_line_generator = (', '.join(map(hex, sub_arr))
- for sub_arr in np.array_split(clip_data, math.ceil(len(clip_data)/20)))
+ for sub_arr in np.array_split(clip_data, math.ceil(len(clip_data) / 20)))
- env.get_template('audio.cc.template').stream(common_template_header=hdr,
- size=len(clip_data),
- var_name=array_name,
- audio_data=hex_line_generator) \
+ env \
+ .get_template('audio.cc.template') \
+ .stream(common_template_header=hdr,
+ size=len(clip_data),
+ var_name=array_name,
+ audio_data=hex_line_generator) \
.dump(str(cc_filename))
return len(clip_data)
+def create_audio_cc_file(args, filename, array_name, clip_dirpath):
+ """
+ Create an individual audio cpp file
+
+ @param args: User-specified args
+ @param filename: Audio filename
+ @param array_name: Name of the array in the audio .cc file
+ @param clip_dirpath: Audio file directory path
+ @return: Array length of the audio data written
+ """
+ cc_filename = (Path(args.source_folder_path) /
+ (Path(filename).stem.replace(" ", "_") + ".cc"))
+ audio_filepath = Path(clip_dirpath) / filename
+ audio_sample = GenUtils.read_audio_file(audio_filepath, args.offset, args.duration)
+ resampled_audio = GenUtils.resample_audio_clip(
+ audio_sample, args.sampling_rate, args.mono, args.res_type, args.min_samples
+ )
+ return write_individual_audio_cc_file(
+ resampled_audio, filename, cc_filename, args.license_template, array_name,
+ )
+
+
def main(args):
+ """
+ Convert audio files to .cc + .hpp files
+ @param args: Parsed args
+ """
# Keep the count of the audio files converted
audioclip_idx = 0
audioclip_filenames = []
@@ -131,25 +255,41 @@ def main(args):
audioclip_filenames.append(filename)
# Save the cc file
- cc_filename = Path(args.source_folder_path) / (Path(filename).stem.replace(" ", "_") + ".cc")
array_name = "audio" + str(audioclip_idx)
- array_size = write_individual_audio_cc_file(clip_dirpath, filename, cc_filename, args.license_template, array_name,
- args.sampling_rate, args.mono, args.offset,
- args.duration, args.res_type, args.min_samples)
+ array_size = create_audio_cc_file(args, filename, array_name, clip_dirpath)
audioclip_array_names.append((array_name, array_size))
# Increment audio index
audioclip_idx = audioclip_idx + 1
- except:
+ except OSError:
if args.verbosity:
print(f"Failed to open {filename} as an audio.")
if len(audioclip_filenames) > 0:
- write_hpp_file(header_filepath, common_cc_filepath, args.license_template,
- audioclip_idx, audioclip_filenames, audioclip_array_names)
+ header = env \
+ .get_template(args.license_template) \
+ .render(script_name=Path(__file__).name,
+ gen_time=datetime.datetime.now(),
+ year=datetime.datetime.now().year)
+
+ write_hpp_file(
+ header_filepath,
+ header,
+ audioclip_idx,
+ audioclip_array_names
+ )
+
+ write_cc_file(
+ common_cc_filepath,
+ header,
+ audioclip_idx,
+ audioclip_filenames,
+ audioclip_array_names
+ )
+
else:
raise FileNotFoundError("No valid audio clip files found.")
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_default_input_cpp.py b/scripts/py/gen_default_input_cpp.py
index 093a606..6056dc1 100644
--- a/scripts/py/gen_default_input_cpp.py
+++ b/scripts/py/gen_default_input_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,38 +16,61 @@
"""
Utility script to generate the minimum InputFiles.hpp and cpp files required by an application.
"""
-import datetime
-from pathlib import Path
from argparse import ArgumentParser
+from pathlib import Path
from jinja2 import Environment, FileSystemLoader
+from gen_utils import GenUtils
+
parser = ArgumentParser()
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.")
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-args = parser.parse_args()
+
+# pylint: disable=duplicate-code
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated."
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
+# pylint: enable=duplicate-code
def write_hpp_file(header_file_path, header_template_file):
+ """
+ Write .hpp file
+ @param header_file_path: Header file path
+ @param header_template_file: Header template file
+ """
print(f"++ Generating {header_file_path}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
- env.get_template('default.hpp.template').stream(common_template_header=hdr) \
+ hdr = GenUtils.gen_header(env, header_template_file)
+ env \
+ .get_template('default.hpp.template') \
+ .stream(common_template_header=hdr) \
.dump(str(header_file_path))
def main(args):
+ """
+ Generate InputFiles.hpp + .cpp
+ @param args: Parsed args
+ """
header_filename = "InputFiles.hpp"
header_filepath = Path(args.header_folder_path) / header_filename
write_hpp_file(header_filepath, args.license_template)
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_labels_cpp.py b/scripts/py/gen_labels_cpp.py
index 065ed5d..11d5040 100644
--- a/scripts/py/gen_labels_cpp.py
+++ b/scripts/py/gen_labels_cpp.py
@@ -1,6 +1,6 @@
#!env/bin/python3
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,47 +21,83 @@ NN model output vector) into a vector list initialiser. The intention is for
this script to be called as part of the build framework to auto-generate the
cpp file with labels that can be used in the application without modification.
"""
-import datetime
-from pathlib import Path
from argparse import ArgumentParser
+from pathlib import Path
from jinja2 import Environment, FileSystemLoader
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
# Label file path
-parser.add_argument("--labels_file", type=str, help="Path to the label text file", required=True)
+parser.add_argument(
+ "--labels_file",
+ type=str,
+ help="Path to the label text file",
+ required=True
+)
+
# Output file to be generated
-parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.", required=True)
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.", required=True)
-parser.add_argument("--output_file_name", type=str, help="Required output file name", required=True)
+parser.add_argument(
+ "--source_folder_path",
+ type=str,
+ help="path to source folder to be generated.",
+ required=True
+)
+
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated.",
+ required=True
+)
+
+parser.add_argument(
+ "--output_file_name",
+ type=str,
+ help="Required output file name",
+ required=True
+)
+
# Namespaces
-parser.add_argument("--namespaces", action='append', default=[])
+parser.add_argument(
+ "--namespaces",
+ action='append',
+ default=[]
+)
+
# License template
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
-args = parser.parse_args()
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
+# pylint: enable=duplicate-code
def main(args):
+ """
+ Generate labels .cpp
+ @param args: Parsed args
+ """
# Get the labels from text file
- with open(args.labels_file, "r") as f:
+ with open(args.labels_file, "r", encoding="utf8") as f:
labels = f.read().splitlines()
# No labels?
if len(labels) == 0:
- raise Exception(f"no labels found in {args.label_file}")
+        raise ValueError(f"no labels found in {args.labels_file}")
- header_template = env.get_template(args.license_template)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- file_name=Path(args.labels_file).name,
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, args.license_template, Path(args.labels_file).name)
hpp_filename = Path(args.header_folder_path) / (args.output_file_name + ".hpp")
env.get_template('Labels.hpp.template').stream(common_template_header=hdr,
@@ -78,4 +114,4 @@ def main(args):
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py
index e4933b5..933c189 100644
--- a/scripts/py/gen_model_cpp.py
+++ b/scripts/py/gen_model_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,29 +18,67 @@ Utility script to generate model c file that can be included in the
project directly. This should be called as part of cmake framework
should the models need to be generated at configuration stage.
"""
-import datetime
+import binascii
from argparse import ArgumentParser
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
-import binascii
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True)
-parser.add_argument("--output_dir", help="Output directory", required=True)
-parser.add_argument('-e', '--expression', action='append', default=[], dest="expr")
-parser.add_argument('--header', action='append', default=[], dest="headers")
-parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces")
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-args = parser.parse_args()
+parser.add_argument(
+ "--tflite_path",
+ help="Model (.tflite) path",
+ required=True
+)
+
+parser.add_argument(
+ "--output_dir",
+ help="Output directory",
+ required=True
+)
+
+parser.add_argument(
+ '-e',
+ '--expression',
+ action='append',
+ default=[],
+ dest="expr"
+)
+
+parser.add_argument(
+ '--header',
+ action='append',
+ default=[],
+ dest="headers"
+)
+
+parser.add_argument(
+ '-ns',
+ '--namespaces',
+ action='append',
+ default=[],
+ dest="namespaces"
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
+# pylint: enable=duplicate-code
def get_tflite_data(tflite_path: str) -> list:
"""
Reads a binary file and returns a C style array as a
@@ -63,15 +101,19 @@ def get_tflite_data(tflite_path: str) -> list:
for i in range(0, len(hexstream), 2):
if 0 == (i % hex_digits_per_line):
hexstring += "\n"
- hexstring += '0x' + hexstream[i:i+2] + ", "
+ hexstring += '0x' + hexstream[i:i + 2] + ", "
hexstring += '};\n'
return [hexstring]
def main(args):
+ """
+ Generate models .cpp
+ @param args: Parsed args
+ """
if not Path(args.tflite_path).is_file():
- raise Exception(f"{args.tflite_path} not found")
+ raise ValueError(f"{args.tflite_path} not found")
# Cpp filename:
cpp_filename = (Path(args.output_dir) / (Path(args.tflite_path).name + ".cc")).resolve()
@@ -80,19 +122,16 @@ def main(args):
cpp_filename.parent.mkdir(exist_ok=True)
- header_template = env.get_template(args.license_template)
-
- hdr = header_template.render(script_name=Path(__file__).name,
- file_name=Path(args.tflite_path).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, args.license_template, Path(args.tflite_path).name)
- env.get_template('tflite.cc.template').stream(common_template_header=hdr,
- model_data=get_tflite_data(args.tflite_path),
- expressions=args.expr,
- additional_headers=args.headers,
- namespaces=args.namespaces).dump(str(cpp_filename))
+ env \
+ .get_template('tflite.cc.template') \
+ .stream(common_template_header=hdr,
+ model_data=get_tflite_data(args.tflite_path),
+ expressions=args.expr,
+ additional_headers=args.headers,
+ namespaces=args.namespaces).dump(str(cpp_filename))
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
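For reference, the hex formatting performed by get_tflite_data (partially visible in the hunk above) amounts to hex-encoding the .tflite bytes and emitting them as a comma-separated C-style byte list. A condensed sketch with the line wrapping and array declaration omitted; 'model.tflite' is a placeholder path:

import binascii
from pathlib import Path

# Condensed sketch of the hexstream idea used by get_tflite_data().
model_bytes = Path("model.tflite").read_bytes()
hexstream = binascii.hexlify(model_bytes).decode("ascii")
c_bytes = ", ".join("0x" + hexstream[i:i + 2] for i in range(0, len(hexstream), 2))
print(c_bytes[:80], "...")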
diff --git a/scripts/py/gen_rgb_cpp.py b/scripts/py/gen_rgb_cpp.py
index b8d85ee..e1c93bb 100644
--- a/scripts/py/gen_rgb_cpp.py
+++ b/scripts/py/gen_rgb_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,87 +18,179 @@ Utility script to convert a set of RGB images in a given location into
corresponding cpp files and a single hpp file referencing the vectors
from the cpp files.
"""
-import datetime
import glob
import math
-from pathlib import Path
+import typing
from argparse import ArgumentParser
+from dataclasses import dataclass
+from pathlib import Path
import numpy as np
from PIL import Image, UnidentifiedImageError
from jinja2 import Environment, FileSystemLoader
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--image_path", type=str, help="path to images folder or image file to convert.")
-parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.")
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.")
-parser.add_argument("--image_size", type=int, nargs=2, help="Size (width and height) of the converted images.")
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-args = parser.parse_args()
+
+parser.add_argument(
+ "--image_path",
+ type=str,
+ help="path to images folder or image file to convert."
+)
+
+parser.add_argument(
+ "--source_folder_path",
+ type=str,
+ help="path to source folder to be generated."
+)
+
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated."
+)
+
+parser.add_argument(
+ "--image_size",
+ type=int,
+ nargs=2,
+ help="Size (width and height) of the converted images."
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
-def write_hpp_file(header_file_path, cc_file_path, header_template_file, num_images, image_filenames,
- image_array_names, image_size):
+# pylint: enable=duplicate-code
+@dataclass
+class ImagesParams:
+ """
+ Template params for Images.hpp and Images.cc
+ """
+ num_images: int
+ image_size: typing.Sequence
+ image_array_names: typing.List[str]
+ image_filenames: typing.List[str]
+
+
+def write_hpp_file(
+ images_params: ImagesParams,
+ header_file_path: Path,
+ cc_file_path: Path,
+ header_template_file: str,
+):
+ """
+ Write Images.hpp and Images.cc
+
+ @param images_params: Template params
+ @param header_file_path: Images.hpp path
+ @param cc_file_path: Images.cc path
+ @param header_template_file: Header template file name
+ """
print(f"++ Generating {header_file_path}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
- env.get_template('Images.hpp.template').stream(common_template_header=hdr,
- imgs_count=num_images,
- img_size=str(image_size[0] * image_size[1] * 3),
- var_names=image_array_names) \
+ hdr = GenUtils.gen_header(env, header_template_file)
+
+ image_size = str(images_params.image_size[0] * images_params.image_size[1] * 3)
+
+ env \
+ .get_template('Images.hpp.template') \
+ .stream(common_template_header=hdr,
+ imgs_count=images_params.num_images,
+ img_size=image_size,
+ var_names=images_params.image_array_names) \
.dump(str(header_file_path))
- env.get_template('Images.cc.template').stream(common_template_header=hdr,
- var_names=image_array_names,
- img_names=image_filenames) \
+ env \
+ .get_template('Images.cc.template') \
+ .stream(common_template_header=hdr,
+ var_names=images_params.image_array_names,
+ img_names=images_params.image_filenames) \
.dump(str(cc_file_path))
-def write_individual_img_cc_file(image_filename, cc_filename, header_template_file, original_image,
- image_size, array_name):
- print(f"++ Converting {image_filename} to {cc_filename.name}")
+def resize_crop_image(
+ original_image: Image.Image,
+ image_size: typing.Sequence
+) -> np.ndarray:
+ """
+ Resize and crop input image
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- file_name=image_filename,
- year=datetime.datetime.now().year)
+ @param original_image: Image to resize and crop
+ @param image_size: New image size
+ @return: Resized and cropped image
+ """
# IFM size
ifm_width = image_size[0]
ifm_height = image_size[1]
# Aspect ratio resize
- scale_ratio = (float)(max(ifm_width, ifm_height)) / (float)(min(original_image.size[0], original_image.size[1]))
- resized_width = (int)(original_image.size[0] * scale_ratio)
- resized_height = (int)(original_image.size[1] * scale_ratio)
- resized_image = original_image.resize([resized_width,resized_height], Image.Resampling.BILINEAR)
+ scale_ratio = (float(max(ifm_width, ifm_height))
+ / float(min(original_image.size[0], original_image.size[1])))
+ resized_width = int(original_image.size[0] * scale_ratio)
+ resized_height = int(original_image.size[1] * scale_ratio)
+ resized_image = original_image.resize(
+ size=(resized_width, resized_height),
+ resample=Image.Resampling.BILINEAR
+ )
# Crop the center of the image
resized_image = resized_image.crop((
- (resized_width - ifm_width) / 2, # left
- (resized_height - ifm_height) / 2, # top
- (resized_width + ifm_width) / 2, # right
+ (resized_width - ifm_width) / 2, # left
+ (resized_height - ifm_height) / 2, # top
+ (resized_width + ifm_width) / 2, # right
(resized_height + ifm_height) / 2 # bottom
- ))
+ ))
+
+ return np.array(resized_image, dtype=np.uint8).flatten()
+
+
+def write_individual_img_cc_file(
+ rgb_data: np.ndarray,
+ image_filename: str,
+ cc_filename: Path,
+ header_template_file: str,
+ array_name: str
+):
+ """
+ Write image.cc
+
+ @param rgb_data: Image data
+ @param image_filename: Image file name
+ @param cc_filename: image.cc path
+ @param header_template_file: Header template file name
+ @param array_name: C++ array name
+ """
+ print(f"++ Converting {image_filename} to {cc_filename.name}")
+
+ hdr = GenUtils.gen_header(env, header_template_file, image_filename)
- # Convert the image and write it to the cc file
- rgb_data = np.array(resized_image, dtype=np.uint8).flatten()
hex_line_generator = (', '.join(map(hex, sub_arr))
for sub_arr in np.array_split(rgb_data, math.ceil(len(rgb_data) / 20)))
- env.get_template('image.cc.template').stream(common_template_header=hdr,
- var_name=array_name,
- img_data=hex_line_generator) \
+ env \
+ .get_template('image.cc.template') \
+ .stream(common_template_header=hdr,
+ var_name=array_name,
+ img_data=hex_line_generator) \
.dump(str(cc_filename))
def main(args):
+ """
+ Convert images
+ @param args: Parsed args
+ """
# Keep the count of the images converted
image_idx = 0
image_filenames = []
@@ -123,26 +215,29 @@ def main(args):
image_filenames.append(filename)
# Save the cc file
- cc_filename = Path(args.source_folder_path) / (Path(filename).stem.replace(" ", "_") + ".cc")
+ cc_filename = (Path(args.source_folder_path) /
+ (Path(filename).stem.replace(" ", "_") + ".cc"))
array_name = "im" + str(image_idx)
image_array_names.append(array_name)
- write_individual_img_cc_file(filename, cc_filename, args.license_template,
- original_image, args.image_size, array_name)
+
+ rgb_data = resize_crop_image(original_image, args.image_size)
+ write_individual_img_cc_file(
+ rgb_data, filename, cc_filename, args.license_template, array_name
+ )
# Increment image index
image_idx = image_idx + 1
- header_filename = "InputFiles.hpp"
- header_filepath = Path(args.header_folder_path) / header_filename
- common_cc_filename = "InputFiles.cc"
- common_cc_filepath = Path(args.source_folder_path) / common_cc_filename
+ header_filepath = Path(args.header_folder_path) / "InputFiles.hpp"
+ common_cc_filepath = Path(args.source_folder_path) / "InputFiles.cc"
+
+ images_params = ImagesParams(image_idx, args.image_size, image_array_names, image_filenames)
if len(image_filenames) > 0:
- write_hpp_file(header_filepath, common_cc_filepath, args.license_template,
- image_idx, image_filenames, image_array_names, args.image_size)
+ write_hpp_file(images_params, header_filepath, common_cc_filepath, args.license_template)
else:
raise FileNotFoundError("No valid images found.")
if __name__ == '__main__':
- main(args)
+ main(parsed_args)
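The new resize_crop_image helper performs an aspect-ratio-preserving resize followed by a centre crop and returns a flattened uint8 RGB array. An illustrative standalone call, assuming the function is in scope and using a placeholder image path with a 224x224 target:

from PIL import Image

# Illustrative use of resize_crop_image; 'sample.bmp' is a placeholder path.
original_image = Image.open("sample.bmp").convert("RGB")
rgb_data = resize_crop_image(original_image, image_size=(224, 224))
print(rgb_data.shape, rgb_data.dtype)  # expected: (224*224*3,) uint8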
diff --git a/scripts/py/gen_test_data_cpp.py b/scripts/py/gen_test_data_cpp.py
index a9e2b75..1ee55ff 100644
--- a/scripts/py/gen_test_data_cpp.py
+++ b/scripts/py/gen_test_data_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 - 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,81 +18,170 @@ Utility script to convert a set of pairs of npy files in a given location into
corresponding cpp files and a single hpp file referencing the vectors
from the cpp files.
"""
-import datetime
import math
-import os
-import numpy as np
+import typing
+from argparse import ArgumentParser
+from dataclasses import dataclass
from pathlib import Path
-from argparse import ArgumentParser
+import numpy as np
from jinja2 import Environment, FileSystemLoader
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--data_folder_path", type=str, help="path to ifm-ofm npy folder to convert.")
-parser.add_argument("--source_folder_path", type=str, help="path to source folder to be generated.")
-parser.add_argument("--header_folder_path", type=str, help="path to header folder to be generated.")
-parser.add_argument("--usecase", type=str, default="", help="Test data file suffix.")
-parser.add_argument("--namespaces", action='append', default=[])
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-parser.add_argument("-v", "--verbosity", action="store_true")
-args = parser.parse_args()
+parser.add_argument(
+ "--data_folder_path",
+ type=str,
+ help="path to ifm-ofm npy folder to convert."
+)
+
+parser.add_argument(
+ "--source_folder_path",
+ type=str,
+ help="path to source folder to be generated."
+)
+
+parser.add_argument(
+ "--header_folder_path",
+ type=str,
+ help="path to header folder to be generated."
+)
+
+parser.add_argument(
+ "--usecase",
+ type=str,
+ default="",
+ help="Test data file suffix."
+)
+
+parser.add_argument(
+ "--namespaces",
+ action='append',
+ default=[]
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parser.add_argument(
+ "-v",
+ "--verbosity",
+ action="store_true"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
-def write_hpp_file(header_filename, cc_file_path, header_template_file, num_ifms, num_ofms,
- ifm_array_names, ifm_sizes, ofm_array_names, ofm_sizes, iofm_data_type):
- header_file_path = Path(args.header_folder_path) / header_filename
+# pylint: enable=duplicate-code
+@dataclass
+class TestDataParams:
+ """
+    Template params for TestData.hpp + TestData.cc
+ """
+ ifm_count: int
+ ofm_count: int
+ ifm_var_names: typing.List[str]
+ ifm_var_sizes: typing.List[int]
+ ofm_var_names: typing.List[str]
+ ofm_var_sizes: typing.List[int]
+ data_type: str
+
+
+@dataclass
+class IofmParams:
+ """
+ Template params for iofmdata.cc
+ """
+ var_name: str
+ data_type: str
+
+
+def write_hpp_file(
+ template_params: TestDataParams,
+ header_filename: str,
+ cc_file_path: str,
+ header_template_file: str
+):
+ """
+ Write TestData.hpp and TestData.cc
+
+ @param template_params: Template parameters
+ @param header_filename: TestData.hpp path
+ @param cc_file_path: TestData.cc path
+ @param header_template_file: Header template file name
+ """
+ header_file_path = Path(parsed_args.header_folder_path) / header_filename
print(f"++ Generating {header_file_path}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
- env.get_template('TestData.hpp.template').stream(common_template_header=hdr,
- ifm_count=num_ifms,
- ofm_count=num_ofms,
- ifm_var_names=ifm_array_names,
- ifm_var_sizes=ifm_sizes,
- ofm_var_names=ofm_array_names,
- ofm_var_sizes=ofm_sizes,
- data_type=iofm_data_type,
- namespaces=args.namespaces) \
+ hdr = GenUtils.gen_header(env, header_template_file)
+ env \
+ .get_template('TestData.hpp.template') \
+ .stream(common_template_header=hdr,
+ ifm_count=template_params.ifm_count,
+ ofm_count=template_params.ofm_count,
+ ifm_var_names=template_params.ifm_var_names,
+ ifm_var_sizes=template_params.ifm_var_sizes,
+ ofm_var_names=template_params.ofm_var_names,
+ ofm_var_sizes=template_params.ofm_var_sizes,
+ data_type=template_params.data_type,
+ namespaces=parsed_args.namespaces) \
.dump(str(header_file_path))
- env.get_template('TestData.cc.template').stream(common_template_header=hdr,
- include_h=header_filename,
- ifm_var_names=ifm_array_names,
- ofm_var_names=ofm_array_names,
- data_type=iofm_data_type,
- namespaces=args.namespaces) \
+ env \
+ .get_template('TestData.cc.template') \
+ .stream(common_template_header=hdr,
+ include_h=header_filename,
+ ifm_var_names=template_params.ifm_var_names,
+ ofm_var_names=template_params.ofm_var_names,
+ data_type=template_params.data_type,
+ namespaces=parsed_args.namespaces) \
.dump(str(cc_file_path))
-def write_individual_cc_file(filename, cc_filename, header_filename, header_template_file, array_name, iofm_data_type):
+def write_individual_cc_file(
+ template_params: IofmParams,
+ header_filename: str,
+ filename: str,
+ cc_filename: Path,
+ header_template_file: str
+):
+ """
+ Write iofmdata.cc
+
+ @param template_params: Template parameters
+ @param header_filename: Header file name
+ @param filename: Input file name
+ @param cc_filename: iofmdata.cc file name
+ @param header_template_file: Header template file name
+ """
print(f"++ Converting {filename} to {cc_filename.name}")
- header_template = env.get_template(header_template_file)
- hdr = header_template.render(script_name=Path(__file__).name,
- gen_time=datetime.datetime.now(),
- file_name=filename,
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, header_template_file, filename)
# Convert the image and write it to the cc file
- fm_data = (np.load(Path(args.data_folder_path) / filename)).flatten()
+ fm_data = (np.load(Path(parsed_args.data_folder_path) / filename)).flatten()
type(fm_data.dtype)
hex_line_generator = (', '.join(map(hex, sub_arr))
for sub_arr in np.array_split(fm_data, math.ceil(len(fm_data) / 20)))
- env.get_template('iofmdata.cc.template').stream(common_template_header=hdr,
- include_h=header_filename,
- var_name=array_name,
- fm_data=hex_line_generator,
- data_type=iofm_data_type,
- namespaces=args.namespaces) \
+ env \
+ .get_template('iofmdata.cc.template') \
+ .stream(common_template_header=hdr,
+ include_h=header_filename,
+ var_name=template_params.var_name,
+ fm_data=hex_line_generator,
+ data_type=template_params.data_type,
+ namespaces=parsed_args.namespaces) \
.dump(str(cc_filename))
@@ -104,59 +193,117 @@ def get_npy_vec_size(filename: str) -> int:
Return:
size in bytes
"""
- data = np.load(Path(args.data_folder_path) / filename)
+ data = np.load(Path(parsed_args.data_folder_path) / filename)
return data.size * data.dtype.itemsize
-def main(args):
- # Keep the count of the images converted
- ifm_array_names = []
- ofm_array_names = []
+def write_cc_files(args, count, iofm_data_type, add_usecase_fname, prefix):
+ """
+ Write all cc files
+
+ @param args: User-provided args
+ @param count: File count
+ @param iofm_data_type: Data type
+ @param add_usecase_fname: Use case suffix
+ @param prefix: Prefix (ifm/ofm)
+ @return: Names and sizes of generated C++ arrays
+ """
+ array_names = []
+ sizes = []
+
+ header_filename = get_header_filename(add_usecase_fname)
+ # In the data_folder_path there should be pairs of ifm-ofm
+ # It's assumed the ifm-ofm naming convention: ifm0.npy-ofm0.npy, ifm1.npy-ofm1.npy
+ # count = int(len(list(Path(args.data_folder_path).glob(f'{prefix}*.npy'))))
+
+ for idx in range(count):
+ # Save the fm cc file
+ base_name = prefix + str(idx)
+ filename = base_name + ".npy"
+ array_name = base_name + add_usecase_fname
+ cc_filename = Path(args.source_folder_path) / (array_name + ".cc")
+ array_names.append(array_name)
+
+ template_params = IofmParams(
+ var_name=array_name,
+ data_type=iofm_data_type,
+ )
+
+ write_individual_cc_file(
+ template_params, header_filename, filename, cc_filename, args.license_template
+ )
+ sizes.append(get_npy_vec_size(filename))
+
+ return array_names, sizes
+
+
+def get_header_filename(use_case_filename):
+ """
+ Get the header file name from the use case file name
+
+ @param use_case_filename: The use case file name
+ @return: The header file name
+ """
+ return "TestData" + use_case_filename + ".hpp"
+
+
+def get_cc_filename(use_case_filename):
+ """
+ Get the cc file name from the use case file name
+
+ @param use_case_filename: The use case file name
+ @return: The cc file name
+ """
+ return "TestData" + use_case_filename + ".cc"
+
+
+def main(args):
+ """
+ Generate test data
+ @param args: Parsed args
+ """
add_usecase_fname = ("_" + args.usecase) if (args.usecase != "") else ""
- header_filename = "TestData" + add_usecase_fname + ".hpp"
- common_cc_filename = "TestData" + add_usecase_fname + ".cc"
+ header_filename = get_header_filename(add_usecase_fname)
+ common_cc_filename = get_cc_filename(add_usecase_fname)
# In the data_folder_path there should be pairs of ifm-ofm
# It's assumed the ifm-ofm naming convention: ifm0.npy-ofm0.npy, ifm1.npy-ofm1.npy
ifms_count = int(len(list(Path(args.data_folder_path).glob('ifm*.npy'))))
ofms_count = int(len(list(Path(args.data_folder_path).glob('ofm*.npy'))))
- #i_ofms_count = int(len([name for name in os.listdir(os.path.join(args.data_folder_path)) if name.lower().endswith('.npy')]) / 2)
-
iofm_data_type = "int8_t"
if ifms_count > 0:
- iofm_data_type = "int8_t" if (np.load(Path(args.data_folder_path) / "ifm0.npy").dtype == np.int8) else "uint8_t"
-
- ifm_sizes = []
- ofm_sizes = []
+ iofm_data_type = "int8_t" \
+ if (np.load(str(Path(args.data_folder_path) / "ifm0.npy")).dtype == np.int8) \
+ else "uint8_t"
- for idx in range(ifms_count):
- # Save the fm cc file
- base_name = "ifm" + str(idx)
- filename = base_name+".npy"
- array_name = base_name + add_usecase_fname
- cc_filename = Path(args.source_folder_path) / (array_name + ".cc")
- ifm_array_names.append(array_name)
- write_individual_cc_file(filename, cc_filename, header_filename, args.license_template, array_name, iofm_data_type)
- ifm_sizes.append(get_npy_vec_size(filename))
+ ifm_array_names, ifm_sizes = write_cc_files(
+ args, ifms_count, iofm_data_type, add_usecase_fname, prefix="ifm"
+ )
- for idx in range(ofms_count):
- # Save the fm cc file
- base_name = "ofm" + str(idx)
- filename = base_name+".npy"
- array_name = base_name + add_usecase_fname
- cc_filename = Path(args.source_folder_path) / (array_name + ".cc")
- ofm_array_names.append(array_name)
- write_individual_cc_file(filename, cc_filename, header_filename, args.license_template, array_name, iofm_data_type)
- ofm_sizes.append(get_npy_vec_size(filename))
+ ofm_array_names, ofm_sizes = write_cc_files(
+ args, ofms_count, iofm_data_type, add_usecase_fname, prefix="ofm"
+ )
common_cc_filepath = Path(args.source_folder_path) / common_cc_filename
- write_hpp_file(header_filename, common_cc_filepath, args.license_template,
- ifms_count, ofms_count, ifm_array_names, ifm_sizes, ofm_array_names, ofm_sizes, iofm_data_type)
+
+ template_params = TestDataParams(
+ ifm_count=ifms_count,
+ ofm_count=ofms_count,
+ ifm_var_names=ifm_array_names,
+ ifm_var_sizes=ifm_sizes,
+ ofm_var_names=ofm_array_names,
+ ofm_var_sizes=ofm_sizes,
+ data_type=iofm_data_type,
+ )
+
+ write_hpp_file(
+ template_params, header_filename, common_cc_filepath, args.license_template
+ )
if __name__ == '__main__':
- if args.verbosity:
- print("Running gen_test_data_cpp with args: "+str(args))
- main(args)
+ if parsed_args.verbosity:
+ print("Running gen_test_data_cpp with args: " + str(parsed_args))
+ main(parsed_args)
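
For context on the refactored generator above: each ifm/ofm tensor is flattened and emitted as comma-separated hex lines of at most 20 values, which iofmdata.cc.template turns into a C++ array. A minimal, self-contained sketch of that chunking step, using a small synthetic array rather than a real .npy file:

    import math
    import numpy as np

    # Hypothetical stand-in for a flattened ifm/ofm tensor.
    fm_data = np.arange(50, dtype=np.int8).flatten()

    # Split into chunks of at most 20 values and render each chunk as a hex line,
    # mirroring the generator expression used by gen_test_data_cpp.py.
    hex_lines = [', '.join(map(hex, sub_arr))
                 for sub_arr in np.array_split(fm_data, math.ceil(len(fm_data) / 20))]

    for line in hex_lines:
        print(line)
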
diff --git a/scripts/py/gen_utils.py b/scripts/py/gen_utils.py
index ee33705..6bb4760 100644
--- a/scripts/py/gen_utils.py
+++ b/scripts/py/gen_utils.py
@@ -1,6 +1,6 @@
#!env/bin/python3
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,21 +14,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import soundfile as sf
-import resampy
+"""
+Utility functions for .cc + .hpp file generation
+"""
+import argparse
+import datetime
+from dataclasses import dataclass
+from pathlib import Path
+
+import jinja2
import numpy as np
+import resampy
+import soundfile as sf
-class AudioUtils:
+@dataclass
+class AudioSample:
+ """
+ Represents an audio sample with its sample rate
+ """
+ data: np.ndarray
+ sample_rate: int
+
+
+class GenUtils:
+ """
+ Class with utility functions for audio and other .cc + .hpp file generation
+ """
+
@staticmethod
def res_data_type(res_type_value):
"""
Returns the input string if it is one of the valid resample types
"""
- import argparse
- if res_type_value not in AudioUtils.res_type_list():
- raise argparse.ArgumentTypeError(f"{res_type_value} not valid. Supported only {AudioUtils.res_type_list()}")
+ if res_type_value not in GenUtils.res_type_list():
+ raise argparse.ArgumentTypeError(
+ f"{res_type_value} not valid. Supported only {GenUtils.res_type_list()}"
+ )
return res_type_value
@staticmethod
@@ -39,27 +61,18 @@ class AudioUtils:
return ['kaiser_best', 'kaiser_fast']
@staticmethod
- def load_resample_audio_clip(path, target_sr=16000, mono=True, offset=0.0, duration=0, res_type='kaiser_best',
- min_len=16000):
+ def read_audio_file(
+ path,
+ offset,
+ duration
+ ) -> AudioSample:
"""
- Load and resample an audio clip with the given desired specs.
+ Reads an audio file to an array
- Parameters:
- ----------
- path (string): Path to the input audio clip.
- target_sr (int, optional): Target sampling rate. Positive number are considered valid,
- if zero or negative the native sampling rate of the file will be preserved. Default is 16000.
- mono (bool, optional): Specify if the audio file needs to be converted to mono. Default is True.
- offset (float, optional): Target sampling rate. Default is 0.0.
- duration (int, optional): Target duration. Positive number are considered valid,
- if zero or negative the duration of the file will be preserved. Default is 0.
- res_type (int, optional): Resample type to use, Default is 'kaiser_best'.
- min_len (int, optional): Minimun lenght of the output audio time series. Default is 16000.
-
- Returns:
- ----------
- y (np.ndarray): Output audio time series of shape shape=(n,) or (2, n).
- sr (int): A scalar number > 0 that represent the sampling rate of `y`
+ @param path: Path to audio file
+ @param offset: Offset to read from
+ @param duration: Duration to read
+ @return: The audio data and the sample rate
"""
try:
with sf.SoundFile(path) as audio_file:
@@ -76,40 +89,115 @@ class AudioUtils:
# Load the target number of frames
y = audio_file.read(frames=num_frame_duration, dtype=np.float32, always_2d=False).T
-
- except:
+ except OSError as err:
print(f"Failed to open {path} as an audio.")
+ raise err
+
+ return AudioSample(y, origin_sr)
+
+ @staticmethod
+ def _resample_audio(
+ y,
+ target_sr,
+ origin_sr,
+ res_type
+ ):
+ """
+ Resamples audio to a different sample rate
+
+ @param y: Audio to resample
+ @param target_sr: Target sample rate
+ @param origin_sr: Original sample rate
+ @param res_type: Resample type
+ @return: The resampled audio
+ """
+ ratio = float(target_sr) / origin_sr
+ axis = -1
+ n_samples = int(np.ceil(y.shape[axis] * ratio))
+
+ # Resample using resampy
+ y_rs = resampy.resample(y, origin_sr, target_sr, filter=res_type, axis=axis)
+ n_rs_samples = y_rs.shape[axis]
+
+ # Adjust the size
+ if n_rs_samples > n_samples:
+ slices = [slice(None)] * y_rs.ndim
+ slices[axis] = slice(0, n_samples)
+ y = y_rs[tuple(slices)]
+ elif n_rs_samples < n_samples:
+ lengths = [(0, 0)] * y_rs.ndim
+ lengths[axis] = (0, n_samples - n_rs_samples)
+ y = np.pad(y_rs, lengths, 'constant', constant_values=0)
+
+ return y
+
+ @staticmethod
+ def resample_audio_clip(
+ audio_sample: AudioSample,
+ target_sr=16000,
+ mono=True,
+ res_type='kaiser_best',
+ min_len=16000
+ ) -> AudioSample:
+ """
+ Resample an audio clip to the given desired specs.
+
+ Parameters:
+ ----------
+ audio_sample (AudioSample): Input audio data and its original sampling rate.
+ target_sr (int, optional): Target sampling rate. Positive numbers are considered valid;
+ if zero or negative, the original sampling rate of the clip
+ is preserved. Default is 16000.
+ mono (bool, optional): Specify if the audio needs to be converted to mono.
+ Default is True.
+ res_type (str, optional): Resample type to use. Default is 'kaiser_best'.
+ min_len (int, optional): Minimum length of the output audio time series.
+ Default is 16000.
+
+ Returns:
+ ----------
+ AudioSample: Output audio time series of shape=(n,) or (2, n)
+ and its sampling rate (a scalar number > 0).
+ """
+ y = audio_sample.data.copy()
# Convert to mono if requested and if audio has more than one dimension
- if mono and (y.ndim > 1):
+ if mono and (audio_sample.data.ndim > 1):
y = np.mean(y, axis=0)
- if not (origin_sr == target_sr) and (target_sr > 0):
- ratio = float(target_sr) / origin_sr
- axis = -1
- n_samples = int(np.ceil(y.shape[axis] * ratio))
-
- # Resample using resampy
- y_rs = resampy.resample(y, origin_sr, target_sr, filter=res_type, axis=axis)
- n_rs_samples = y_rs.shape[axis]
-
- # Adjust the size
- if n_rs_samples > n_samples:
- slices = [slice(None)] * y_rs.ndim
- slices[axis] = slice(0, n_samples)
- y = y_rs[tuple(slices)]
- elif n_rs_samples < n_samples:
- lengths = [(0, 0)] * y_rs.ndim
- lengths[axis] = (0, n_samples - n_rs_samples)
- y = np.pad(y_rs, lengths, 'constant', constant_values=(0))
-
- sr = target_sr
+ if not (audio_sample.sample_rate == target_sr) and (target_sr > 0):
+ y = GenUtils._resample_audio(y, target_sr, audio_sample.sample_rate, res_type)
+ sample_rate = target_sr
else:
- sr = origin_sr
+ sample_rate = audio_sample.sample_rate
# Pad if necessary and a minimum length is set (min_len > 0)
if (y.shape[0] < min_len) and (min_len > 0):
sample_to_pad = min_len - y.shape[0]
- y = np.pad(y, (0, sample_to_pad), 'constant', constant_values=(0))
+ y = np.pad(y, (0, sample_to_pad), 'constant', constant_values=0)
+
+ return AudioSample(data=y, sample_rate=sample_rate)
- return y, sr
+ @staticmethod
+ def gen_header(
+ env: jinja2.Environment,
+ header_template_file: str,
+ file_name: str = None
+ ) -> str:
+ """
+ Generate common licence header
+
+ :param env: Jinja2 environment
+ :param header_template_file: Path to the licence header template
+ :param file_name: Optional generating script file name
+ :return: Generated licence header as a string
+ """
+ header_template = env.get_template(header_template_file)
+ return header_template.render(script_name=Path(__file__).name,
+ gen_time=datetime.datetime.now(),
+ file_name=file_name,
+ year=datetime.datetime.now().year)
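
The old AudioUtils.load_resample_audio_clip has effectively been split in two: GenUtils.read_audio_file loads the clip and GenUtils.resample_audio_clip converts it, passing an AudioSample between them. A short usage sketch, assuming the repository root is on sys.path and that the input path points at a real audio file readable by soundfile:

    from scripts.py.gen_utils import GenUtils

    # Hypothetical input clip; offset/duration of 0 keep the whole file.
    sample = GenUtils.read_audio_file("clips/example.wav", offset=0, duration=0)

    # Resample to 16 kHz mono and pad to at least 16000 samples,
    # matching the defaults used by the audio generation scripts.
    resampled = GenUtils.resample_audio_clip(
        sample, target_sr=16000, mono=True, res_type='kaiser_best', min_len=16000
    )
    print(resampled.sample_rate, resampled.data.shape)
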
diff --git a/scripts/py/git_pre_push_hooks.sh b/scripts/py/git_pre_push_hooks.sh
new file mode 100755
index 0000000..db5706f
--- /dev/null
+++ b/scripts/py/git_pre_push_hooks.sh
@@ -0,0 +1,48 @@
+#!/bin/sh
+# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Called by "git push" with no arguments. The hook should
+# exit with non-zero status after issuing an appropriate message if
+# it wants to stop the push.
+
+# shellcheck disable=SC2034,SC2162
+while read local_ref local_sha remote_ref remote_sha; do
+ # We should pass only added or modified C/C++ source files to cppcheck.
+ changed_files=$(git diff --name-only HEAD~1 HEAD | grep -iE "\.(c|cpp|cxx|cc|h|hpp|hxx)$" | cut -f 2)
+ if [ -n "$changed_files" ]; then
+ # shellcheck disable=SC2086
+ clang-format -style=file --dry-run --Werror $changed_files
+
+ exitcode1=$?
+ if [ $exitcode1 -ne 0 ]; then
+ echo "Formatting errors found in file: $changed_files. \
+ Please run:
+ \"clang-format -style=file -i $changed_files\"
+ to correct these errors"
+ exit $exitcode1
+ fi
+
+ # shellcheck disable=SC2086
+ cppcheck --enable=performance,portability --error-exitcode=1 --suppress=*:tests* $changed_files
+ exitcode2=$?
+ if [ $exitcode2 -ne 0 ]; then
+ exit $exitcode2
+ fi
+ fi
+ exit 0
+done
+
+exit 0
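
The hook runs once per pushed ref: it filters the last commit's changed files down to C/C++ sources and headers, then runs clang-format in dry-run mode and cppcheck over them. For illustration only, a rough Python equivalent of the changed-file filter, assuming it is run inside a git checkout with at least two commits:

    import re
    import subprocess

    # List files touched by the last commit and keep C/C++ sources and headers,
    # mirroring the grep -iE filter used by git_pre_push_hooks.sh.
    diff = subprocess.run(
        ["git", "diff", "--name-only", "HEAD~1", "HEAD"],
        check=True, capture_output=True, text=True
    )
    pattern = re.compile(r"\.(c|cpp|cxx|cc|h|hpp|hxx)$", re.IGNORECASE)
    changed_files = [f for f in diff.stdout.splitlines() if pattern.search(f)]
    print(changed_files)
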
diff --git a/scripts/py/rnnoise_dump_extractor.py b/scripts/py/rnnoise_dump_extractor.py
index 715b922..9e6ff1f 100644
--- a/scripts/py/rnnoise_dump_extractor.py
+++ b/scripts/py/rnnoise_dump_extractor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,47 +20,84 @@ Example use:
python rnnoise_dump_extractor.py --dump_file output.bin --output_dir ./denoised_wavs/
"""
-import soundfile as sf
-import numpy as np
-
import argparse
-from os import path
import struct
+import typing
+from os import path
+
+import numpy as np
+import soundfile as sf
-def extract(fp, output_dir, export_npy):
+def extract(
+ dump_file: typing.IO,
+ output_dir: str,
+ export_npy: bool
+):
+ """
+ Extract audio file from RNNoise output dump
+
+ @param dump_file: Audio dump file location
+ @param output_dir: Output directory
+ @param export_npy: Whether to export the audio as .npy
+ """
while True:
- filename_length = struct.unpack("i", fp.read(4))[0]
+ filename_length = struct.unpack("i", dump_file.read(4))[0]
if filename_length == -1:
return
- filename = struct.unpack("{}s".format(filename_length), fp.read(filename_length))[0].decode('ascii')
- audio_clip_length = struct.unpack("I", fp.read(4))[0]
- output_file_name = path.join(output_dir, "denoised_{}".format(filename))
- audio_clip = fp.read(audio_clip_length)
-
- with sf.SoundFile(output_file_name, 'w', channels=1, samplerate=48000, subtype="PCM_16", endian="LITTLE") as wav_file:
+ filename = struct \
+ .unpack(f"{filename_length}s", dump_file.read(filename_length))[0] \
+ .decode('ascii')
+
+ audio_clip_length = struct.unpack("I", dump_file.read(4))[0]
+ output_file_name = path.join(output_dir, f"denoised_{filename}")
+ audio_clip = dump_file.read(audio_clip_length)
+
+ with sf.SoundFile(output_file_name, 'w', channels=1, samplerate=48000, subtype="PCM_16",
+ endian="LITTLE") as wav_file:
wav_file.buffer_write(audio_clip, dtype='int16')
- print("{} written to disk".format(output_file_name))
+ print(f"{output_file_name} written to disk")
if export_npy:
output_file_name += ".npy"
- pack_format = "{}h".format(int(audio_clip_length/2))
+ pack_format = f"{int(audio_clip_length / 2)}h"
npdata = np.array(struct.unpack(pack_format, audio_clip)).astype(np.int16)
np.save(output_file_name, npdata)
- print("{} written to disk".format(output_file_name))
+ print(f"{output_file_name} written to disk")
def main(args):
+ """
+ Run RNNoise audio dump extraction
+ @param args: Parsed args
+ """
extract(args.dump_file, args.output_dir, args.export_npy)
parser = argparse.ArgumentParser()
-parser.add_argument("--dump_file", type=argparse.FileType('rb'), help="Dump file with audio files to extract.", required=True)
-parser.add_argument("--output_dir", help="Output directory, Warning: Duplicated file names will be overwritten.", required=True)
-parser.add_argument("--export_npy", help="Export the audio buffer in NumPy format", action="store_true")
-args = parser.parse_args()
+
+parser.add_argument(
+ "--dump_file",
+ type=argparse.FileType('rb'),
+ help="Dump file with audio files to extract.",
+ required=True
+)
+
+parser.add_argument(
+ "--output_dir",
+ help="Output directory, Warning: Duplicated file names will be overwritten.",
+ required=True
+)
+
+parser.add_argument(
+ "--export_npy",
+ help="Export the audio buffer in NumPy format",
+ action="store_true"
+)
+
+parsed_args = parser.parse_args()
if __name__ == "__main__":
- main(args)
+ main(parsed_args)
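
extract() walks a simple length-prefixed record format: an int32 filename length (-1 terminates the stream), the filename bytes, a uint32 clip length, then the raw PCM16 clip. A self-contained sketch of that layout, building one hypothetical record in memory and reading it back with struct:

    import io
    import struct

    # Build one hypothetical dump record in memory:
    # int32 name length, name bytes, uint32 clip length, PCM16 samples, then -1.
    clip = struct.pack("4h", 0, 100, -100, 0)       # four fake PCM16 samples
    name = b"clip0.wav"
    dump = (struct.pack("i", len(name)) + name +
            struct.pack("I", len(clip)) + clip +
            struct.pack("i", -1))                   # stream terminator

    fp = io.BytesIO(dump)
    name_len = struct.unpack("i", fp.read(4))[0]
    filename = struct.unpack(f"{name_len}s", fp.read(name_len))[0].decode("ascii")
    clip_len = struct.unpack("I", fp.read(4))[0]
    samples = struct.unpack(f"{clip_len // 2}h", fp.read(clip_len))
    print(filename, samples)
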
diff --git a/scripts/py/setup_hooks.py b/scripts/py/setup_hooks.py
index ead5e1f..dc3156c 100644
--- a/scripts/py/setup_hooks.py
+++ b/scripts/py/setup_hooks.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
-# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
-# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,84 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import os
-import sys
+"""
+Adds the git hooks script into the appropriate location
+"""
import argparse
+import os
+import shutil
import subprocess
-import stat
+import sys
+from pathlib import Path
+
+HOOKS_SCRIPT = "git_pre_push_hooks.sh"
+
-def set_hooks_dir(hooks_dir):
- command = 'git config core.hooksPath {}'.format(hooks_dir)
- subprocess.Popen(command.split(), stdout=subprocess.PIPE)
+def set_hooks_dir(hooks_dir: str):
+ """
+ Set the hooks path in the git configuration
+ @param hooks_dir: The hooks directory
+ """
+ command = f'git config core.hooksPath {hooks_dir}'
+ with subprocess.Popen(command.split(), stdout=subprocess.PIPE) as process:
+ process.communicate()
+ return_code = process.returncode
-def add_pre_push_hooks(hooks_dir):
+ if return_code != 0:
+ raise RuntimeError(f"Could not configure git hooks path, exited with code {return_code}")
+
+
+def add_pre_push_hooks(hooks_dir: str):
+ """
+ Copies the git hooks scripts into the specified location
+ @param hooks_dir: The specified git hooks directory
+ """
pre_push = "pre-push"
file_path = os.path.join(hooks_dir, pre_push)
file_exists = os.path.exists(file_path)
if file_exists:
os.remove(file_path)
- f = open(file_path, "a")
- f.write(
-'''#!/bin/sh
-# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Called by "git push" with no arguments. The hook should
-# exit with non-zero status after issuing an appropriate message if
-# it wants to stop the push.
-
-while read local_ref local_sha remote_ref remote_sha
-do
- # We should pass only added or modified C/C++ source files to cppcheck.
- changed_files=$(git diff --name-only HEAD~1 HEAD | grep -iE "\.(c|cpp|cxx|cc|h|hpp|hxx)$" | cut -f 2)
- if [ -n "$changed_files" ]; then
- clang-format -style=file --dry-run --Werror $changed_files
-
- exitcode1=$?
- if [ $exitcode1 -ne 0 ]; then
- echo "Formatting errors found in file: $changed_files.
- \nPlease run:\n\ \"clang-format -style=file -i $changed_files\"
- \nto correct these errors"
- exit $exitcode1
- fi
-
- cppcheck --enable=performance,portability --error-exitcode=1 --suppress=*:tests* $changed_files
- exitcode2=$?
- if [ $exitcode2 -ne 0 ]; then
- exit $exitcode2
- fi
- fi
- exit 0
-done
-exit 0'''
-)
+ script_path = Path(__file__).resolve().parent / HOOKS_SCRIPT
+ shutil.copy(script_path, hooks_dir)
- f.close()
- s = os.stat(file_path)
- os.chmod(file_path, s.st_mode | stat.S_IEXEC)
-parser = argparse.ArgumentParser()
-parser.add_argument("git_hooks_path")
-args = parser.parse_args()
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("git_hooks_path")
+ args = parser.parse_args()
-dir_exists = os.path.exists(args.git_hooks_path)
-if not dir_exists:
- print('Error! The Git hooks directory you supplied does not exist.')
- sys.exit()
+ if not os.path.exists(args.git_hooks_path):
+ print('Error! The Git hooks directory you supplied does not exist.')
+ sys.exit()
-add_pre_push_hooks(args.git_hooks_path)
-set_hooks_dir(args.git_hooks_path)
+ add_pre_push_hooks(args.git_hooks_path)
+ set_hooks_dir(args.git_hooks_path)
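
setup_hooks.py now copies git_pre_push_hooks.sh into the supplied hooks directory as the pre-push hook and points git's core.hooksPath at it, raising if git exits with a non-zero code. A minimal sketch of the same configuration step using subprocess.run; the hooks directory shown is hypothetical:

    import subprocess

    def configure_hooks_path(hooks_dir: str) -> None:
        # Passing an argv list avoids shell word-splitting; check=True raises
        # CalledProcessError on a non-zero exit instead of inspecting returncode.
        subprocess.run(["git", "config", "core.hooksPath", hooks_dir], check=True)

    # configure_hooks_path("tools/git-hooks")
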
diff --git a/scripts/py/templates/header_template.txt b/scripts/py/templates/header_template.txt
index f6e3bdb..32bf71a 100644
--- a/scripts/py/templates/header_template.txt
+++ b/scripts/py/templates/header_template.txt
@@ -16,6 +16,6 @@
*/
/********************* Autogenerated file. DO NOT EDIT *******************
- * Generated from {{script_name}} tool {% if file_name %}and {{file_name}}{% endif %} file.
+ * Generated from {{script_name}} tool {% if file_name %}and {{file_name}} {% endif %}file.
* Date: {{gen_time}}
***************************************************************************/
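
This template is rendered by GenUtils.gen_header(), shown earlier in gen_utils.py, with script_name, gen_time, file_name and year. A small sketch of rendering it directly with Jinja2; the loader path and the file_name value are assumptions for illustration:

    import datetime
    from pathlib import Path

    import jinja2

    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader("scripts/py/templates"),
        trim_blocks=True, lstrip_blocks=True
    )
    hdr = env.get_template("header_template.txt").render(
        script_name=Path(__file__).name,
        gen_time=datetime.datetime.now(),
        file_name="ifm0.npy",                # hypothetical source file name
        year=datetime.datetime.now().year
    )
    print(hdr)
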
diff --git a/scripts/py/use_case_resources.json b/scripts/py/use_case_resources.json
new file mode 100644
index 0000000..80fa28d
--- /dev/null
+++ b/scripts/py/use_case_resources.json
@@ -0,0 +1,190 @@
+[
+ {
+ "name": "ad",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "ad_medium_int8.tflite",
+ "url": "{url_prefix:0}ad_medium_int8.tflite"
+ },
+ {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
+ {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"}
+ ]
+ },
+ {
+ "name": "asr",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/"
+ ],
+ "resources": [
+ {
+ "name": "wav2letter_pruned_int8.tflite",
+ "url": "{url_prefix:0}wav2letter_pruned_int8.tflite"
+ },
+ {
+ "name": "ifm0.npy",
+ "url": "{url_prefix:0}testing_input/input_2_int8/0.npy"
+ },
+ {
+ "name": "ofm0.npy",
+ "url": "{url_prefix:0}testing_output/Identity_int8/0.npy"
+ }
+ ]
+ },
+ {
+ "name": "img_class",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "mobilenet_v2_1.0_224_INT8.tflite",
+ "url": "{url_prefix:0}mobilenet_v2_1.0_224_INT8.tflite"
+ },
+ {
+ "name": "ifm0.npy",
+ "url": "{url_prefix:0}testing_input/tfl.quantize/0.npy"
+ },
+ {
+ "name": "ofm0.npy",
+ "url": "{url_prefix:0}testing_output/MobilenetV2/Predictions/Reshape_11/0.npy"
+ }
+ ]
+ },
+ {
+ "name": "object_detection",
+ "url_prefix": [
+ "https://github.com/emza-vs/ModelZoo/blob/v1.0/object_detection/"
+ ],
+ "resources": [
+ {
+ "name": "yolo-fastest_192_face_v4.tflite",
+ "url": "{url_prefix:0}yolo-fastest_192_face_v4.tflite?raw=true"
+ }
+ ]
+ },
+ {
+ "name": "kws",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/"
+ ],
+ "resources": [
+ {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
+ {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"},
+ {
+ "name": "kws_micronet_m.tflite",
+ "url": "{url_prefix:0}kws_micronet_m.tflite"
+ }
+ ]
+ },
+ {
+ "name": "vww",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "vww4_128_128_INT8.tflite",
+ "url": "{url_prefix:0}vww4_128_128_INT8.tflite"
+ },
+ {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
+ {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"}
+ ]
+ },
+ {
+ "name": "kws_asr",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/",
+ "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "wav2letter_pruned_int8.tflite",
+ "url": "{url_prefix:0}wav2letter_pruned_int8.tflite"
+ },
+ {
+ "sub_folder": "asr",
+ "name": "ifm0.npy",
+ "url": "{url_prefix:0}testing_input/input_2_int8/0.npy"
+ },
+ {
+ "sub_folder": "asr",
+ "name": "ofm0.npy",
+ "url": "{url_prefix:0}testing_output/Identity_int8/0.npy"
+ },
+ {
+ "sub_folder": "kws",
+ "name": "ifm0.npy",
+ "url": "{url_prefix:1}testing_input/input/0.npy"
+ },
+ {
+ "sub_folder": "kws",
+ "name": "ofm0.npy",
+ "url": "{url_prefix:1}testing_output/Identity/0.npy"
+ },
+ {
+ "name": "kws_micronet_m.tflite",
+ "url": "{url_prefix:1}kws_micronet_m.tflite"
+ }
+ ]
+ },
+ {
+ "name": "noise_reduction",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/a061600058097a2785d6f1f7785e5a2d2a142955/models/noise_suppression/RNNoise/tflite_int8/"
+ ],
+ "resources": [
+ {"name": "rnnoise_INT8.tflite", "url": "{url_prefix:0}rnnoise_INT8.tflite"},
+ {
+ "name": "ifm0.npy",
+ "url": "{url_prefix:0}testing_input/main_input_int8/0.npy"
+ },
+ {
+ "name": "ifm1.npy",
+ "url": "{url_prefix:0}testing_input/vad_gru_prev_state_int8/0.npy"
+ },
+ {
+ "name": "ifm2.npy",
+ "url": "{url_prefix:0}testing_input/noise_gru_prev_state_int8/0.npy"
+ },
+ {
+ "name": "ifm3.npy",
+ "url": "{url_prefix:0}testing_input/denoise_gru_prev_state_int8/0.npy"
+ },
+ {
+ "name": "ofm0.npy",
+ "url": "{url_prefix:0}testing_output/Identity_int8/0.npy"
+ },
+ {
+ "name": "ofm1.npy",
+ "url": "{url_prefix:0}testing_output/Identity_1_int8/0.npy"
+ },
+ {
+ "name": "ofm2.npy",
+ "url": "{url_prefix:0}testing_output/Identity_2_int8/0.npy"
+ },
+ {
+ "name": "ofm3.npy",
+ "url": "{url_prefix:0}testing_output/Identity_3_int8/0.npy"
+ },
+ {
+ "name": "ofm4.npy",
+ "url": "{url_prefix:0}testing_output/Identity_4_int8/0.npy"
+ }
+ ]
+ },
+ {
+ "name": "inference_runner",
+ "url_prefix": [
+ "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/dnn_small/tflite_int8/"
+ ],
+ "resources": [
+ {
+ "name": "dnn_s_quantized.tflite",
+ "url": "{url_prefix:0}dnn_s_quantized.tflite"
+ }
+ ]
+ }
+]
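
Each resource URL in this file is a template: the index in "{url_prefix:N}" selects an entry from the use case's url_prefix list, as download_resources() below resolves it. A minimal sketch with a hypothetical prefix:

    import re

    url_prefixes = ["https://example.com/models/"]           # hypothetical prefix list
    resource_url = "{url_prefix:0}testing_input/input/0.npy"

    # Pick the prefix by index, then strip the placeholder from the template.
    pattern = re.compile(r"{url_prefix:(.*\d)}")
    idx = int(pattern.search(resource_url).group(1))
    full_url = url_prefixes[idx] + re.sub(r"{url_prefix:(.*\d)}", "", resource_url)
    print(full_url)   # -> https://example.com/models/testing_input/input/0.npy
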
diff --git a/set_up_default_resources.py b/set_up_default_resources.py
index f983508..a3987bc 100755
--- a/set_up_default_resources.py
+++ b/set_up_default_resources.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Script to set up default resources for ML Embedded Evaluation Kit
+"""
+import dataclasses
import errno
import fnmatch
import json
@@ -22,207 +26,21 @@ import re
import shutil
import subprocess
import sys
+import typing
import urllib.request
import venv
from argparse import ArgumentParser
from argparse import ArgumentTypeError
from collections import namedtuple
-from urllib.error import URLError
+from dataclasses import dataclass
from pathlib import Path
+from urllib.error import URLError
from scripts.py.check_update_resources_downloaded import get_md5sum_for_file
-
-json_uc_res = [
- {
- "use_case_name": "ad",
- "url_prefix": [
- "https://github.com/ARM-software/ML-zoo/raw/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/"
- ],
- "resources": [
- {
- "name": "ad_medium_int8.tflite",
- "url": "{url_prefix:0}ad_medium_int8.tflite",
- },
- {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
- {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"},
- ],
- },
- {
- "use_case_name": "asr",
- "url_prefix": [
- "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/"
- ],
- "resources": [
- {
- "name": "wav2letter_pruned_int8.tflite",
- "url": "{url_prefix:0}wav2letter_pruned_int8.tflite",
- },
- {
- "name": "ifm0.npy",
- "url": "{url_prefix:0}testing_input/input_2_int8/0.npy",
- },
- {
- "name": "ofm0.npy",
- "url": "{url_prefix:0}testing_output/Identity_int8/0.npy",
- },
- ],
- },
- {
- "use_case_name": "img_class",
- "url_prefix": [
- "https://github.com/ARM-software/ML-zoo/raw/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8/"
- ],
- "resources": [
- {
- "name": "mobilenet_v2_1.0_224_INT8.tflite",
- "url": "{url_prefix:0}mobilenet_v2_1.0_224_INT8.tflite",
- },
- {
- "name": "ifm0.npy",
- "url": "{url_prefix:0}testing_input/tfl.quantize/0.npy",
- },
- {
- "name": "ofm0.npy",
- "url": "{url_prefix:0}testing_output/MobilenetV2/Predictions/Reshape_11/0.npy",
- },
- ],
- },
- {
- "use_case_name": "object_detection",
- "url_prefix": [
- "https://github.com/emza-vs/ModelZoo/blob/v1.0/object_detection/"
- ],
- "resources": [
- {
- "name": "yolo-fastest_192_face_v4.tflite",
- "url": "{url_prefix:0}yolo-fastest_192_face_v4.tflite?raw=true",
- }
- ],
- },
- {
- "use_case_name": "kws",
- "url_prefix": [
- "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/"
- ],
- "resources": [
- {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
- {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"},
- {
- "name": "kws_micronet_m.tflite",
- "url": "{url_prefix:0}kws_micronet_m.tflite",
- },
- ],
- },
- {
- "use_case_name": "vww",
- "url_prefix": [
- "https://github.com/ARM-software/ML-zoo/raw/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8/"
- ],
- "resources": [
- {
- "name": "vww4_128_128_INT8.tflite",
- "url": "{url_prefix:0}vww4_128_128_INT8.tflite",
- },
- {"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"},
- {"name": "ofm0.npy", "url": "{url_prefix:0}testing_output/Identity/0.npy"},
- ],
- },
- {
- "use_case_name": "kws_asr",
- "url_prefix": [
- "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/",
- "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/",
- ],
- "resources": [
- {
- "name": "wav2letter_pruned_int8.tflite",
- "url": "{url_prefix:0}wav2letter_pruned_int8.tflite",
- },
- {
- "sub_folder": "asr",
- "name": "ifm0.npy",
- "url": "{url_prefix:0}testing_input/input_2_int8/0.npy",
- },
- {
- "sub_folder": "asr",
- "name": "ofm0.npy",
- "url": "{url_prefix:0}testing_output/Identity_int8/0.npy",
- },
- {
- "sub_folder": "kws",
- "name": "ifm0.npy",
- "url": "{url_prefix:1}testing_input/input/0.npy",
- },
- {
- "sub_folder": "kws",
- "name": "ofm0.npy",
- "url": "{url_prefix:1}testing_output/Identity/0.npy",
- },
- {
- "name": "kws_micronet_m.tflite",
- "url": "{url_prefix:1}kws_micronet_m.tflite",
- },
- ],
- },
- {
- "use_case_name": "noise_reduction",
- "url_prefix": [
- "https://github.com/ARM-software/ML-zoo/raw/a061600058097a2785d6f1f7785e5a2d2a142955/models/noise_suppression/RNNoise/tflite_int8/"
- ],
- "resources": [
- {"name": "rnnoise_INT8.tflite", "url": "{url_prefix:0}rnnoise_INT8.tflite"},
- {
- "name": "ifm0.npy",
- "url": "{url_prefix:0}testing_input/main_input_int8/0.npy",
- },
- {
- "name": "ifm1.npy",
- "url": "{url_prefix:0}testing_input/vad_gru_prev_state_int8/0.npy",
- },
- {
- "name": "ifm2.npy",
- "url": "{url_prefix:0}testing_input/noise_gru_prev_state_int8/0.npy",
- },
- {
- "name": "ifm3.npy",
- "url": "{url_prefix:0}testing_input/denoise_gru_prev_state_int8/0.npy",
- },
- {
- "name": "ofm0.npy",
- "url": "{url_prefix:0}testing_output/Identity_int8/0.npy",
- },
- {
- "name": "ofm1.npy",
- "url": "{url_prefix:0}testing_output/Identity_1_int8/0.npy",
- },
- {
- "name": "ofm2.npy",
- "url": "{url_prefix:0}testing_output/Identity_2_int8/0.npy",
- },
- {
- "name": "ofm3.npy",
- "url": "{url_prefix:0}testing_output/Identity_3_int8/0.npy",
- },
- {
- "name": "ofm4.npy",
- "url": "{url_prefix:0}testing_output/Identity_4_int8/0.npy",
- },
- ],
- },
- {
- "use_case_name": "inference_runner",
- "url_prefix": [
- "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/dnn_small/tflite_int8/"
- ],
- "resources": [
- {
- "name": "dnn_s_quantized.tflite",
- "url": "{url_prefix:0}dnn_s_quantized.tflite",
- }
- ],
- },
-]
+# Supported versions of Python and Vela
+VELA_VERSION = "3.9.0"
+py3_version_minimum = (3, 9)
# Valid NPU configurations:
valid_npu_config_names = [
@@ -250,10 +68,57 @@ NPUConfig = namedtuple(
],
)
+
+@dataclass(frozen=True)
+class UseCaseResource:
+ """
+ Represent a use case's resource
+ """
+ name: str
+ url: str
+ sub_folder: typing.Optional[str] = None
+
+
+@dataclass(frozen=True)
+class UseCase:
+ """
+ Represent a use case
+ """
+ name: str
+ url_prefix: str
+ resources: typing.List[UseCaseResource]
+
+
# The internal SRAM size for Corstone-300 implementation on MPS3 specified by AN552
# The internal SRAM size for Corstone-310 implementation on MPS3 specified by AN555
# is 4MB, but we are content with the 2MB specified below.
-mps3_max_sram_sz = 2 * 1024 * 1024 # 2 MiB (2 banks of 1 MiB each)
+MPS3_MAX_SRAM_SZ = 2 * 1024 * 1024 # 2 MiB (2 banks of 1 MiB each)
+
+
+def load_use_case_resources(current_file_dir: Path) -> typing.List[UseCase]:
+ """
+ Load use case metadata resources
+
+ Parameters
+ ----------
+ current_file_dir: Directory of the current script
+
+ Returns
+ -------
+ A list of UseCase objects parsed from the use case resources JSON file
+ """
+
+ resources_path = current_file_dir / "scripts" / "py" / "use_case_resources.json"
+ with open(resources_path, encoding="utf8") as f:
+ use_cases = json.load(f)
+ return [
+ UseCase(
+ name=u["name"],
+ url_prefix=u["url_prefix"],
+ resources=[UseCaseResource(**r) for r in u["resources"]],
+ )
+ for u in use_cases
+ ]
def call_command(command: str, verbose: bool = True) -> str:
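
load_use_case_resources() turns each JSON entry into frozen dataclasses, and update_metadata() later serialises them back with dataclasses.asdict() when writing resources_downloaded_metadata.json. A self-contained sketch of that round trip, using mirror dataclass definitions and a hypothetical entry (the real ones live in set_up_default_resources.py and use_case_resources.json):

    import dataclasses
    import typing
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class UseCaseResource:
        name: str
        url: str
        sub_folder: typing.Optional[str] = None

    @dataclass(frozen=True)
    class UseCase:
        name: str
        url_prefix: typing.List[str]   # list of prefixes, matching the JSON layout
        resources: typing.List[UseCaseResource]

    entry = {                          # hypothetical use case entry
        "name": "kws",
        "url_prefix": ["https://example.com/kws/"],
        "resources": [{"name": "ifm0.npy", "url": "{url_prefix:0}testing_input/input/0.npy"}],
    }
    use_case = UseCase(
        name=entry["name"],
        url_prefix=entry["url_prefix"],
        resources=[UseCaseResource(**r) for r in entry["resources"]],
    )
    print(dataclasses.asdict(use_case))
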
@@ -266,21 +131,22 @@ def call_command(command: str, verbose: bool = True) -> str:
"""
if verbose:
logging.info(command)
- proc = subprocess.run(
- command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
- )
- log = proc.stdout.decode("utf-8")
- if proc.returncode == 0 and verbose:
+ try:
+ proc = subprocess.run(
+ command, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
+ )
+ log = proc.stdout.decode("utf-8")
logging.info(log)
- else:
+ return log
+ except subprocess.CalledProcessError as err:
+ log = err.stdout.decode("utf-8")
logging.error(log)
- proc.check_returncode()
- return log
+ raise err
def get_default_npu_config_from_name(
- config_name: str, arena_cache_size: int = 0
-) -> NPUConfig:
+ config_name: str, arena_cache_size: int = 0
+) -> typing.Optional[NPUConfig]:
"""
Gets the file suffix for the TFLite file from the
`accelerator_config` string.
@@ -314,15 +180,15 @@ def get_default_npu_config_from_name(
system_configs = ["Ethos_U55_High_End_Embedded", "Ethos_U65_High_End"]
memory_modes_arena = {
# For shared SRAM memory mode, we use the MPS3 SRAM size by default.
- "Shared_Sram": mps3_max_sram_sz if arena_cache_size <= 0 else arena_cache_size,
+ "Shared_Sram": MPS3_MAX_SRAM_SZ if arena_cache_size <= 0 else arena_cache_size,
# For dedicated SRAM memory mode, we do not override the arena size. This is expected to
# be defined in the Vela configuration file instead.
"Dedicated_Sram": None if arena_cache_size <= 0 else arena_cache_size,
}
- for i in range(len(strings_ids)):
- if config_name.startswith(strings_ids[i]):
- npu_config_id = config_name.replace(strings_ids[i], prefix_ids[i])
+ for i, string_id in enumerate(strings_ids):
+ if config_name.startswith(string_id):
+ npu_config_id = config_name.replace(string_id, prefix_ids[i])
return NPUConfig(
config_name=config_name,
memory_mode=memory_modes[i],
@@ -335,97 +201,316 @@ def get_default_npu_config_from_name(
return None
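
get_default_npu_config_from_name() now advertises that it may return None for an unrecognised name, so callers should guard for it. A short usage sketch, assuming the script is imported from the repository root and that "ethos-u55-128" is one of the valid configuration names:

    from set_up_default_resources import get_default_npu_config_from_name

    config = get_default_npu_config_from_name("ethos-u55-128", arena_cache_size=0)
    if config is None:
        raise ValueError("Unrecognised NPU configuration name")
    # NPUConfig fields used elsewhere in this script when invoking Vela.
    print(config.config_name, config.memory_mode, config.system_config)
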
-def remove_tree_dir(dir_path):
+def remove_tree_dir(dir_path: Path):
+ """
+ Delete and re-create a directory
+
+ Parameters
+ ----------
+ dir_path : The directory path
+ """
try:
# Remove the full directory.
shutil.rmtree(dir_path)
# Re-create an empty one.
os.mkdir(dir_path)
- except Exception as e:
- logging.error(f"Failed to delete {dir_path}.")
+ except OSError:
+ logging.error("Failed to delete %s.", dir_path)
-def set_up_resources(
- run_vela_on_models: bool = False,
- additional_npu_config_names: tuple = (),
- arena_cache_size: int = 0,
- check_clean_folder: bool = False,
- additional_requirements_file: str = "") -> (Path, Path):
+def initialize_use_case_resources_directory(
+ use_case: UseCase,
+ metadata: typing.Dict,
+ download_dir: Path,
+ check_clean_folder: bool,
+ setup_script_hash_verified: bool,
+):
"""
- Helpers function that retrieve the output from a command.
+ Initialize the resources_downloaded directory for a use case
- Parameters:
- ----------
- run_vela_on_models (bool): Specifies if run vela on downloaded models.
- additional_npu_config_names(list): list of strings of Ethos-U NPU configs.
- arena_cache_size (int): Specifies arena cache size in bytes. If a value
- greater than 0 is provided, this will be taken
- as the cache size. If 0, the default values, as per
- the NPU config requirements, are used.
- check_clean_folder (bool): Indicates whether the resources folder needs to
- be checked for updates and cleaned.
- additional_requirements_file (str): Path to a requirements.txt file if
- additional packages need to be
- installed.
+ @param use_case: The use case
+ @param metadata: The metadata
+ @param download_dir: The parent directory
+ @param check_clean_folder: Whether to clean the folder
+ @param setup_script_hash_verified: Whether the hash of this script is verified
+ """
+ try:
+ # Does the usecase_name download dir exist?
+ (download_dir / use_case.name).mkdir()
+ except OSError as err:
+ if err.errno == errno.EEXIST:
+ # The usecase_name download dir exist.
+ if check_clean_folder and not setup_script_hash_verified:
+ for idx, metadata_uc_url_prefix in enumerate(
+ [
+ f
+ for f in metadata["resources_info"]
+ if f["name"] == use_case.name
+ ][0]["url_prefix"]
+ ):
+ if metadata_uc_url_prefix != use_case.url_prefix[idx]:
+ logging.info("Removing %s resources.", use_case.name)
+ remove_tree_dir(download_dir / use_case.name)
+ break
+ elif err.errno != errno.EEXIST:
+ logging.error("Error creating %s directory.", use_case.name)
+ raise
- Returns
- -------
- Tuple of pair of Paths: (download_directory_path, virtual_env_path)
+def download_file(url: str, dest: Path):
+ """
+ Download a file
- download_directory_path: Root of the directory where the resources have been downloaded to.
- virtual_env_path: Path to the root of virtual environment.
+ @param url: The URL of the file to download
+ @param dest: The destination of downloaded file
"""
- # Paths.
- current_file_dir = Path(__file__).parent.resolve()
- download_dir = current_file_dir / "resources_downloaded"
- metadata_file_path = download_dir / "resources_downloaded_metadata.json"
+ try:
+ with urllib.request.urlopen(url) as g:
+ with open(dest, "b+w") as f:
+ f.write(g.read())
+ logging.info("- Downloaded %s to %s.", url, dest)
+ except URLError:
+ logging.error("URLError while downloading %s.", url)
+ raise
+
+
+def download_resources(
+ use_case: UseCase,
+ metadata: typing.Dict,
+ download_dir: Path,
+ check_clean_folder: bool,
+ setup_script_hash_verified: bool,
+):
+ """
+ Download the resources associated with a use case
- metadata_dict = dict()
- vela_version = "3.9.0"
- py3_version_minimum = (3, 9)
+ @param use_case: The use case
+ @param metadata: The metadata
+ @param download_dir: The parent directory
+ @param check_clean_folder: Whether to clean the folder
+ @param setup_script_hash_verified: Whether the hash is already verified
+ """
+ initialize_use_case_resources_directory(
+ use_case,
+ metadata,
+ download_dir,
+ check_clean_folder,
+ setup_script_hash_verified
+ )
- # Is Python minimum requirement matched?
- py3_version = sys.version_info
- if py3_version < py3_version_minimum:
- raise Exception(
- "ERROR: Python3.9+ is required, please see the documentation on how to update it."
+ reg_expr_str = r"{url_prefix:(.*\d)}"
+ reg_expr_pattern = re.compile(reg_expr_str)
+ for res in use_case.resources:
+ res_name = res.name
+ url_prefix_idx = int(reg_expr_pattern.search(res.url).group(1))
+ res_url = use_case.url_prefix[url_prefix_idx] + re.sub(
+ reg_expr_str, "", res.url
+ )
+
+ sub_folder = ""
+ if res.sub_folder is not None:
+ try:
+ # Does the usecase_name/sub_folder download dir exist?
+ (download_dir / use_case.name / res.sub_folder).mkdir()
+ except OSError as err:
+ if err.errno != errno.EEXIST:
+ logging.error(
+ "Error creating %s/%s directory.",
+ use_case.name,
+ res.sub_folder
+ )
+ raise
+ sub_folder = res.sub_folder
+
+ res_dst = download_dir / use_case.name / sub_folder / res_name
+
+ if res_dst.is_file():
+ logging.info("File %s exists, skipping download.", res_dst)
+ else:
+ download_file(res_url, res_dst)
+
+
+def run_vela(
+ config: NPUConfig,
+ env_activate_cmd: str,
+ model: Path,
+ config_file: Path,
+ output_dir: Path
+) -> bool:
+ """
+ Run vela on the specified model
+ @param config: The NPU configuration
+ @param env_activate_cmd: The Python venv activation command
+ @param model: The model
+ @param config_file: The vela config file
+ @param output_dir: The output directory
+ @return: True if the optimisation was skipped, false otherwise
+ """
+ # model name after compiling with vela is an initial model name + _vela suffix
+ vela_optimised_model_path = model.parent / (model.stem + "_vela.tflite")
+
+ vela_command_arena_cache_size = ""
+
+ if config.arena_cache_size:
+ vela_command_arena_cache_size = (
+ f"--arena-cache-size={config.arena_cache_size}"
+ )
+
+ vela_command = (
+ f"{env_activate_cmd} && vela {model} "
+ + f"--accelerator-config={config.config_name} "
+ + "--optimise Performance "
+ + f"--config {config_file} "
+ + f"--memory-mode={config.memory_mode} "
+ + f"--system-config={config.system_config} "
+ + f"--output-dir={output_dir} "
+ + f"{vela_command_arena_cache_size}"
+ )
+
+ # We want the name to include the configuration suffix. For example: vela_H128,
+ # vela_Y512 etc.
+ new_suffix = "_vela_" + config.ethos_u_config_id + ".tflite"
+ new_vela_optimised_model_path = model.parent / (model.stem + new_suffix)
+
+ skip_optimisation = new_vela_optimised_model_path.is_file()
+
+ if skip_optimisation:
+ logging.info(
+ "File %s exists, skipping optimisation.",
+ new_vela_optimised_model_path
)
else:
- logging.info(f"Using Python version: {py3_version}")
+ call_command(vela_command)
+
+ # Rename default vela model.
+ vela_optimised_model_path.rename(new_vela_optimised_model_path)
+ logging.info(
+ "Renaming %s to %s.",
+ vela_optimised_model_path,
+ new_vela_optimised_model_path
+ )
+ return skip_optimisation
+
+
+def run_vela_on_all_models(
+ current_file_dir: Path,
+ download_dir: Path,
+ env_activate_cmd: str,
+ arena_cache_size: int,
+ npu_config_names: typing.List[str]
+):
+ """
+ Run vela on downloaded models for the specified NPU configurations
+
+ @param current_file_dir: Path to the current directory
+ @param download_dir: Path to the downloaded resources directory
+ @param env_activate_cmd: Command used to activate Python venv
+ @param npu_config_names: Names of NPU configurations for which to run Vela
+ @param arena_cache_size: The arena cache size
+ """
+ config_file = current_file_dir / "scripts" / "vela" / "default_vela.ini"
+ models = [
+ Path(dirpath) / f
+ for dirpath, dirnames, files in os.walk(download_dir)
+ for f in fnmatch.filter(files, "*.tflite")
+ if "vela" not in f
+ ]
+
+ # Get npu config tuple for each config name in a list:
+ npu_configs = [
+ get_default_npu_config_from_name(name, arena_cache_size)
+ for name in npu_config_names
+ ]
+
+ logging.info("All models will be optimised for these configs:")
+ for config in npu_configs:
+ logging.info(config)
+
+ optimisation_skipped = False
+
+ for model in models:
+ for config in npu_configs:
+ optimisation_skipped = run_vela(
+ config,
+ env_activate_cmd,
+ model,
+ config_file,
+ output_dir=model.parent
+ ) or optimisation_skipped
+
+ # If any optimisation was skipped, show how to regenerate:
+ if optimisation_skipped:
+ logging.warning("One or more optimisations were skipped.")
+ logging.warning(
+ "To optimise all the models, please remove the directory %s.",
+ download_dir
+ )
+
+
+def initialize_resources_directory(
+ download_dir: Path,
+ check_clean_folder: bool,
+ metadata_file_path: Path,
+ setup_script_hash: str
+) -> typing.Tuple[typing.Dict, bool]:
+ """
+ Sets up the resources_downloaded directory and checks to see if this script
+ has been modified since the last time resources were downloaded
+
+ @param download_dir: Path to the resources_downloaded directory
+ @param check_clean_folder: Determines whether to clean the downloads directory
+ @param metadata_file_path: Path to the metadata file
+ @param setup_script_hash: The md5 hash of this script
+ @return: The metadata and a boolean to indicate whether this
+ script has changed since it was last run
+ """
+ metadata_dict = {}
setup_script_hash_verified = False
- setup_script_hash = get_md5sum_for_file(Path(__file__).resolve())
- try:
- # 1.1 Does the download dir exist?
+ if download_dir.is_dir():
+ logging.info("'resources_downloaded' directory exists.")
+ # Check and clean?
+ if check_clean_folder and metadata_file_path.is_file():
+ with open(metadata_file_path, encoding="utf8") as metadata_file:
+ metadata_dict = json.load(metadata_file)
+
+ vela_in_metadata = metadata_dict["ethosu_vela_version"]
+ if vela_in_metadata != VELA_VERSION:
+ # Check if all the resources needs to be removed and regenerated.
+ # This can happen when the Vela version has changed.
+ logging.info(
+ "Vela version in metadata is %s, current %s."
+ " Removing the resources to re-download them.",
+ vela_in_metadata,
+ VELA_VERSION
+ )
+ remove_tree_dir(download_dir)
+ metadata_dict = {}
+ else:
+ # Check if set_up_default_resources.py has changed since the last setup
+ setup_script_hash_verified = (
+ metadata_dict.get("set_up_script_md5sum")
+ == setup_script_hash
+ )
+ else:
download_dir.mkdir()
- except OSError as e:
- if e.errno == errno.EEXIST:
- logging.info("'resources_downloaded' directory exists.")
- # Check and clean?
- if check_clean_folder and metadata_file_path.is_file():
- with open(metadata_file_path) as metadata_file:
- metadata_dict = json.load(metadata_file)
- vela_in_metadata = metadata_dict["ethosu_vela_version"]
- if vela_in_metadata != vela_version:
- # Check if all the resources needs to be removed and regenerated.
- # This can happen when the Vela version has changed.
- logging.info(
- f"Vela version in metadata is {vela_in_metadata}, current {vela_version}. Removing the resources and re-download them."
- )
- remove_tree_dir(download_dir)
- metadata_dict = dict()
- else:
- # Check if the set_up_default_resorces.py has changed from last setup
- setup_script_hash_verified = (
- metadata_dict.get("set_up_script_md5sum")
- == setup_script_hash
- )
- else:
- raise
- # 1.2 Does the virtual environment exist?
+ return metadata_dict, setup_script_hash_verified
+
+
+def set_up_python_venv(
+ download_dir: Path,
+ additional_requirements_file: Path = ""
+):
+ """
+ Set up the Python environment with which to set up the resources
+
+ @param download_dir: Path to the resources_downloaded directory
+ @param additional_requirements_file: Optional additional requirements file
+ @return: Path to the venv root directory and the venv activation command
+ """
env_dirname = "env"
env_path = download_dir / env_dirname
@@ -434,23 +519,28 @@ def set_up_resources(
env_python = Path(venv_context.env_exe)
- if sys.platform == "win32":
- env_activate = f"{venv_context.bin_path}/activate.bat"
- else:
- env_activate = f". {venv_context.bin_path}/activate"
-
if not env_python.is_file():
# Create the virtual environment using current interpreter's venv
# (not necessarily the system's Python3)
venv_builder.create(env_dir=env_path)
+ if sys.platform == "win32":
+ env_activate = Path(f"{venv_context.bin_path}/activate.bat")
+ env_activate_cmd = str(env_activate)
+ else:
+ env_activate = Path(f"{venv_context.bin_path}/activate")
+ env_activate_cmd = f". {env_activate}"
+
+ if not env_activate.is_file():
+ venv_builder.install_scripts(venv_context, venv_context.bin_path)
+
# 1.3 Install additional requirements first, if a valid file has been provided
if additional_requirements_file and os.path.isfile(additional_requirements_file):
command = f"{env_python} -m pip install -r {additional_requirements_file}"
call_command(command)
# 1.4 Make sure to have all the main requirements
- requirements = [f"ethos-u-vela=={vela_version}"]
+ requirements = [f"ethos-u-vela=={VELA_VERSION}"]
command = f"{env_python} -m pip freeze"
packages = call_command(command)
for req in requirements:
@@ -458,162 +548,130 @@ def set_up_resources(
command = f"{env_python} -m pip install {req}"
call_command(command)
- # 2. Download models
- logging.info("Downloading resources.")
- for uc in json_uc_res:
- use_case_name = uc["use_case_name"]
- res_url_prefix = uc["url_prefix"]
- try:
- # Does the usecase_name download dir exist?
- (download_dir / use_case_name).mkdir()
- except OSError as e:
- if e.errno == errno.EEXIST:
- # The usecase_name download dir exist.
- if check_clean_folder and not setup_script_hash_verified:
- for idx, metadata_uc_url_prefix in enumerate(
- [
- f
- for f in metadata_dict["resources_info"]
- if f["use_case_name"] == use_case_name
- ][0]["url_prefix"]
- ):
- if metadata_uc_url_prefix != res_url_prefix[idx]:
- logging.info(f"Removing {use_case_name} resources.")
- remove_tree_dir(download_dir / use_case_name)
- break
- elif e.errno != errno.EEXIST:
- logging.error(f"Error creating {use_case_name} directory.")
- raise
-
- reg_expr_str = r"{url_prefix:(.*\d)}"
- reg_expr_pattern = re.compile(reg_expr_str)
- for res in uc["resources"]:
- res_name = res["name"]
- url_prefix_idx = int(reg_expr_pattern.search(res["url"]).group(1))
- res_url = res_url_prefix[url_prefix_idx] + re.sub(
- reg_expr_str, "", res["url"]
- )
+ return env_path, env_activate_cmd
- sub_folder = ""
- if "sub_folder" in res:
- try:
- # Does the usecase_name/sub_folder download dir exist?
- (download_dir / use_case_name / res["sub_folder"]).mkdir()
- except OSError as e:
- if e.errno != errno.EEXIST:
- logging.error(
- f"Error creating {use_case_name} / {res['sub_folder']} directory."
- )
- raise
- sub_folder = res["sub_folder"]
-
- res_dst = download_dir / use_case_name / sub_folder / res_name
-
- if res_dst.is_file():
- logging.info(f"File {res_dst} exists, skipping download.")
- else:
- try:
- g = urllib.request.urlopen(res_url)
- with open(res_dst, "b+w") as f:
- f.write(g.read())
- logging.info(f"- Downloaded {res_url} to {res_dst}.")
- except URLError:
- logging.error(f"URLError while downloading {res_url}.")
- raise
- # 3. Run vela on models in resources_downloaded
- # New models will have same name with '_vela' appended.
- # For example:
- # original model: kws_micronet_m.tflite
- # after vela model: kws_micronet_m_vela_H128.tflite
- #
- # Note: To avoid to run vela twice on the same model, it's supposed that
- # downloaded model names don't contain the 'vela' word.
- if run_vela_on_models is True:
- config_file = current_file_dir / "scripts" / "vela" / "default_vela.ini"
- models = [
- Path(dirpath) / f
- for dirpath, dirnames, files in os.walk(download_dir)
- for f in fnmatch.filter(files, "*.tflite")
- if "vela" not in f
- ]
+def update_metadata(
+ metadata_dict: typing.Dict,
+ setup_script_hash: str,
+ json_uc_res: typing.List[UseCase],
+ metadata_file_path: Path
+):
+ """
+ Update the metadata file
- # Consolidate all config names while discarding duplicates:
- config_names = list(set(default_npu_config_names + additional_npu_config_names))
+ @param metadata_dict: The metadata dictionary to update
+ @param setup_script_hash: The setup script hash
+ @param json_uc_res: The use case resources metadata
+ @param metadata_file_path The metadata file path
+ """
+ metadata_dict["ethosu_vela_version"] = VELA_VERSION
+ metadata_dict["set_up_script_md5sum"] = setup_script_hash.strip("\n")
+ metadata_dict["resources_info"] = [dataclasses.asdict(uc) for uc in json_uc_res]
- # Get npu config tuple for each config name in a list:
- npu_configs = [
- get_default_npu_config_from_name(name, arena_cache_size)
- for name in config_names
- ]
+ with open(metadata_file_path, "w", encoding="utf8") as metadata_file:
+ json.dump(metadata_dict, metadata_file, indent=4)
- logging.info(f"All models will be optimised for these configs:")
- for config in npu_configs:
- logging.info(config)
- optimisation_skipped = False
+def set_up_resources(
+ run_vela_on_models: bool = False,
+ additional_npu_config_names: tuple = (),
+ arena_cache_size: int = 0,
+ check_clean_folder: bool = False,
+ additional_requirements_file: Path = ""
+) -> Path:
+ """
+ Set up the resources and Python virtual environment needed by the evaluation kit.
- for model in models:
- output_dir = model.parent
- # model name after compiling with vela is an initial model name + _vela suffix
- vela_optimised_model_path = model.parent / (model.stem + "_vela.tflite")
+ Parameters:
+ ----------
+ run_vela_on_models (bool): Specifies if run vela on downloaded models.
+ additional_npu_config_names(list): list of strings of Ethos-U NPU configs.
+ arena_cache_size (int): Specifies arena cache size in bytes. If a value
+ greater than 0 is provided, this will be taken
+ as the cache size. If 0, the default values, as per
+ the NPU config requirements, are used.
+ check_clean_folder (bool): Indicates whether the resources folder needs to
+ be checked for updates and cleaned.
+ additional_requirements_file (str): Path to a requirements.txt file if
+ additional packages need to be
+ installed.
- for config in npu_configs:
- vela_command_arena_cache_size = ""
+ Returns
+ -------
- if config.arena_cache_size:
- vela_command_arena_cache_size = (
- f"--arena-cache-size={config.arena_cache_size}"
- )
+ virtual_env_path
- vela_command = (
- f"{env_activate} && vela {model} "
- + f"--accelerator-config={config.config_name} "
- + "--optimise Performance "
- + f"--config {config_file} "
- + f"--memory-mode={config.memory_mode} "
- + f"--system-config={config.system_config} "
- + f"--output-dir={output_dir} "
- + f"{vela_command_arena_cache_size}"
- )
+ virtual_env_path: Path to the root of the virtual environment.
+ """
+ # Paths.
+ current_file_dir = Path(__file__).parent.resolve()
+ download_dir = current_file_dir / "resources_downloaded"
+ metadata_file_path = download_dir / "resources_downloaded_metadata.json"
- # We want the name to include the configuration suffix. For example: vela_H128,
- # vela_Y512 etc.
- new_suffix = "_vela_" + config.ethos_u_config_id + ".tflite"
- new_vela_optimised_model_path = model.parent / (model.stem + new_suffix)
+ # Is Python minimum requirement matched?
+ if sys.version_info < py3_version_minimum:
+ raise RuntimeError(
+ f"ERROR: Python{'.'.join(str(i) for i in py3_version_minimum)}+ is required,"
+ " please see the documentation on how to update it."
+ )
+ logging.info("Using Python version: %s", sys.version_info)
- if new_vela_optimised_model_path.is_file():
- logging.info(
- f"File {new_vela_optimised_model_path} exists, skipping optimisation."
- )
- optimisation_skipped = True
- continue
+ json_uc_res = load_use_case_resources(current_file_dir)
+ setup_script_hash = get_md5sum_for_file(Path(__file__).resolve())
- call_command(vela_command)
+ metadata_dict, setup_script_hash_verified = initialize_resources_directory(
+ download_dir,
+ check_clean_folder,
+ metadata_file_path,
+ setup_script_hash
+ )
- # Rename default vela model.
- vela_optimised_model_path.rename(new_vela_optimised_model_path)
- logging.info(
- f"Renaming {vela_optimised_model_path} to {new_vela_optimised_model_path}."
- )
+ env_path, env_activate = set_up_python_venv(
+ download_dir,
+ additional_requirements_file
+ )
- # If any optimisation was skipped, show how to regenerate:
- if optimisation_skipped:
- logging.warning("One or more optimisations were skipped.")
- logging.warning(
- f"To optimise all the models, please remove the directory {download_dir}."
- )
+ # 2. Download models
+ logging.info("Downloading resources.")
+ for use_case in json_uc_res:
+ download_resources(
+ use_case,
+ metadata_dict,
+ download_dir,
+ check_clean_folder,
+ setup_script_hash_verified
+ )
+
+ # 3. Run vela on models in resources_downloaded
+ # New models will have same name with '_vela' appended.
+ # For example:
+ # original model: kws_micronet_m.tflite
+ # after vela model: kws_micronet_m_vela_H128.tflite
+ #
+ # Note: to avoid running vela twice on the same model, it is assumed that
+ # downloaded model names do not contain the word 'vela'.
+ if run_vela_on_models is True:
+ # Consolidate all config names while discarding duplicates:
+ run_vela_on_all_models(
+ current_file_dir,
+ download_dir,
+ env_activate,
+ arena_cache_size,
+ npu_config_names=list(set(default_npu_config_names + list(additional_npu_config_names)))
+ )
# 4. Collect and write metadata
logging.info("Collecting and write metadata.")
- metadata_dict["ethosu_vela_version"] = vela_version
- metadata_dict["set_up_script_md5sum"] = setup_script_hash.strip("\n")
- metadata_dict["resources_info"] = json_uc_res
-
- with open(metadata_file_path, "w") as metadata_file:
- json.dump(metadata_dict, metadata_file, indent=4)
+ update_metadata(
+ metadata_dict,
+ setup_script_hash.strip("\n"),
+ json_uc_res,
+ metadata_file_path
+ )
- return download_dir, env_path
+ return env_path
if __name__ == "__main__":