diff options
Diffstat (limited to 'scripts/py/gen_model_cpp.py')
-rw-r--r-- | scripts/py/gen_model_cpp.py | 89 |
1 files changed, 64 insertions, 25 deletions
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py index e4933b5..933c189 100644 --- a/scripts/py/gen_model_cpp.py +++ b/scripts/py/gen_model_cpp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com> # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,29 +18,67 @@ Utility script to generate model c file that can be included in the project directly. This should be called as part of cmake framework should the models need to be generated at configuration stage. """ -import datetime +import binascii from argparse import ArgumentParser from pathlib import Path from jinja2 import Environment, FileSystemLoader -import binascii +from gen_utils import GenUtils + +# pylint: disable=duplicate-code parser = ArgumentParser() -parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True) -parser.add_argument("--output_dir", help="Output directory", required=True) -parser.add_argument('-e', '--expression', action='append', default=[], dest="expr") -parser.add_argument('--header', action='append', default=[], dest="headers") -parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces") -parser.add_argument("--license_template", type=str, help="Header template file", - default="header_template.txt") -args = parser.parse_args() +parser.add_argument( + "--tflite_path", + help="Model (.tflite) path", + required=True +) + +parser.add_argument( + "--output_dir", + help="Output directory", + required=True +) + +parser.add_argument( + '-e', + '--expression', + action='append', + default=[], + dest="expr" +) + +parser.add_argument( + '--header', + action='append', + default=[], + dest="headers" +) + +parser.add_argument( + '-ns', + '--namespaces', + action='append', + default=[], + dest="namespaces" +) + +parser.add_argument( + "--license_template", + type=str, + help="Header template file", + default="header_template.txt" +) + +parsed_args = parser.parse_args() env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'), trim_blocks=True, lstrip_blocks=True) +# pylint: enable=duplicate-code def get_tflite_data(tflite_path: str) -> list: """ Reads a binary file and returns a C style array as a @@ -63,15 +101,19 @@ def get_tflite_data(tflite_path: str) -> list: for i in range(0, len(hexstream), 2): if 0 == (i % hex_digits_per_line): hexstring += "\n" - hexstring += '0x' + hexstream[i:i+2] + ", " + hexstring += '0x' + hexstream[i:i + 2] + ", " hexstring += '};\n' return [hexstring] def main(args): + """ + Generate models .cpp + @param args: Parsed args + """ if not Path(args.tflite_path).is_file(): - raise Exception(f"{args.tflite_path} not found") + raise ValueError(f"{args.tflite_path} not found") # Cpp filename: cpp_filename = (Path(args.output_dir) / (Path(args.tflite_path).name + ".cc")).resolve() @@ -80,19 +122,16 @@ def main(args): cpp_filename.parent.mkdir(exist_ok=True) - header_template = env.get_template(args.license_template) - - hdr = header_template.render(script_name=Path(__file__).name, - file_name=Path(args.tflite_path).name, - gen_time=datetime.datetime.now(), - year=datetime.datetime.now().year) + hdr = GenUtils.gen_header(env, args.license_template, Path(args.tflite_path).name) - env.get_template('tflite.cc.template').stream(common_template_header=hdr, - model_data=get_tflite_data(args.tflite_path), - expressions=args.expr, - additional_headers=args.headers, - namespaces=args.namespaces).dump(str(cpp_filename)) + env \ + .get_template('tflite.cc.template') \ + .stream(common_template_header=hdr, + model_data=get_tflite_data(args.tflite_path), + expressions=args.expr, + additional_headers=args.headers, + namespaces=args.namespaces).dump(str(cpp_filename)) if __name__ == '__main__': - main(args) + main(parsed_args) |