diff options
author | Alex Tawse <alex.tawse@arm.com> | 2023-09-29 15:55:38 +0100 |
---|---|---|
committer | Richard <richard.burton@arm.com> | 2023-10-26 12:35:48 +0000 |
commit | daba3cf2e3633cbd0e4f8aabe7578b97e88deee1 (patch) | |
tree | 51024b8025e28ecb2aecd67246e189e25f5a6e6c /scripts/py/gen_model_cpp.py | |
parent | a11976fb866f77305708f832e603b963969e6a14 (diff) | |
download | ml-embedded-evaluation-kit-daba3cf2e3633cbd0e4f8aabe7578b97e88deee1.tar.gz |
MLECO-3995: Pylint + Shellcheck compatibility
* All Python scripts updated to abide by Pylint rules
* good-names updated to permit short variable names:
i, j, k, f, g, ex
* ignore-long-lines regex updated to allow long lines
for licence headers
* Shell scripts now compliant with Shellcheck
Signed-off-by: Alex Tawse <Alex.Tawse@arm.com>
Change-Id: I8d5af8279bc08bb8acfe8f6ee7df34965552bbe5
Diffstat (limited to 'scripts/py/gen_model_cpp.py')
-rw-r--r-- | scripts/py/gen_model_cpp.py | 89 |
1 files changed, 64 insertions, 25 deletions
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py index e4933b5..933c189 100644 --- a/scripts/py/gen_model_cpp.py +++ b/scripts/py/gen_model_cpp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com> # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,29 +18,67 @@ Utility script to generate model c file that can be included in the project directly. This should be called as part of cmake framework should the models need to be generated at configuration stage. """ -import datetime +import binascii from argparse import ArgumentParser from pathlib import Path from jinja2 import Environment, FileSystemLoader -import binascii +from gen_utils import GenUtils + +# pylint: disable=duplicate-code parser = ArgumentParser() -parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True) -parser.add_argument("--output_dir", help="Output directory", required=True) -parser.add_argument('-e', '--expression', action='append', default=[], dest="expr") -parser.add_argument('--header', action='append', default=[], dest="headers") -parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces") -parser.add_argument("--license_template", type=str, help="Header template file", - default="header_template.txt") -args = parser.parse_args() +parser.add_argument( + "--tflite_path", + help="Model (.tflite) path", + required=True +) + +parser.add_argument( + "--output_dir", + help="Output directory", + required=True +) + +parser.add_argument( + '-e', + '--expression', + action='append', + default=[], + dest="expr" +) + +parser.add_argument( + '--header', + action='append', + default=[], + dest="headers" +) + +parser.add_argument( + '-ns', + '--namespaces', + action='append', + default=[], + dest="namespaces" +) + +parser.add_argument( + "--license_template", + type=str, + help="Header template file", + default="header_template.txt" +) + +parsed_args = parser.parse_args() env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'), trim_blocks=True, lstrip_blocks=True) +# pylint: enable=duplicate-code def get_tflite_data(tflite_path: str) -> list: """ Reads a binary file and returns a C style array as a @@ -63,15 +101,19 @@ def get_tflite_data(tflite_path: str) -> list: for i in range(0, len(hexstream), 2): if 0 == (i % hex_digits_per_line): hexstring += "\n" - hexstring += '0x' + hexstream[i:i+2] + ", " + hexstring += '0x' + hexstream[i:i + 2] + ", " hexstring += '};\n' return [hexstring] def main(args): + """ + Generate models .cpp + @param args: Parsed args + """ if not Path(args.tflite_path).is_file(): - raise Exception(f"{args.tflite_path} not found") + raise ValueError(f"{args.tflite_path} not found") # Cpp filename: cpp_filename = (Path(args.output_dir) / (Path(args.tflite_path).name + ".cc")).resolve() @@ -80,19 +122,16 @@ def main(args): cpp_filename.parent.mkdir(exist_ok=True) - header_template = env.get_template(args.license_template) - - hdr = header_template.render(script_name=Path(__file__).name, - file_name=Path(args.tflite_path).name, - gen_time=datetime.datetime.now(), - year=datetime.datetime.now().year) + hdr = GenUtils.gen_header(env, args.license_template, Path(args.tflite_path).name) - env.get_template('tflite.cc.template').stream(common_template_header=hdr, - model_data=get_tflite_data(args.tflite_path), - expressions=args.expr, - additional_headers=args.headers, - namespaces=args.namespaces).dump(str(cpp_filename)) + env \ + .get_template('tflite.cc.template') \ + .stream(common_template_header=hdr, + model_data=get_tflite_data(args.tflite_path), + expressions=args.expr, + additional_headers=args.headers, + namespaces=args.namespaces).dump(str(cpp_filename)) if __name__ == '__main__': - main(args) + main(parsed_args) |