summaryrefslogtreecommitdiff
path: root/scripts/py/gen_model_cpp.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/py/gen_model_cpp.py')
-rw-r--r--scripts/py/gen_model_cpp.py89
1 files changed, 64 insertions, 25 deletions
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py
index e4933b5..933c189 100644
--- a/scripts/py/gen_model_cpp.py
+++ b/scripts/py/gen_model_cpp.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,29 +18,67 @@ Utility script to generate model c file that can be included in the
project directly. This should be called as part of cmake framework
should the models need to be generated at configuration stage.
"""
-import datetime
+import binascii
from argparse import ArgumentParser
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
-import binascii
+from gen_utils import GenUtils
+
+# pylint: disable=duplicate-code
parser = ArgumentParser()
-parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True)
-parser.add_argument("--output_dir", help="Output directory", required=True)
-parser.add_argument('-e', '--expression', action='append', default=[], dest="expr")
-parser.add_argument('--header', action='append', default=[], dest="headers")
-parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces")
-parser.add_argument("--license_template", type=str, help="Header template file",
- default="header_template.txt")
-args = parser.parse_args()
+parser.add_argument(
+ "--tflite_path",
+ help="Model (.tflite) path",
+ required=True
+)
+
+parser.add_argument(
+ "--output_dir",
+ help="Output directory",
+ required=True
+)
+
+parser.add_argument(
+ '-e',
+ '--expression',
+ action='append',
+ default=[],
+ dest="expr"
+)
+
+parser.add_argument(
+ '--header',
+ action='append',
+ default=[],
+ dest="headers"
+)
+
+parser.add_argument(
+ '-ns',
+ '--namespaces',
+ action='append',
+ default=[],
+ dest="namespaces"
+)
+
+parser.add_argument(
+ "--license_template",
+ type=str,
+ help="Header template file",
+ default="header_template.txt"
+)
+
+parsed_args = parser.parse_args()
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
+# pylint: enable=duplicate-code
def get_tflite_data(tflite_path: str) -> list:
"""
Reads a binary file and returns a C style array as a
@@ -63,15 +101,19 @@ def get_tflite_data(tflite_path: str) -> list:
for i in range(0, len(hexstream), 2):
if 0 == (i % hex_digits_per_line):
hexstring += "\n"
- hexstring += '0x' + hexstream[i:i+2] + ", "
+ hexstring += '0x' + hexstream[i:i + 2] + ", "
hexstring += '};\n'
return [hexstring]
def main(args):
+ """
+ Generate models .cpp
+ @param args: Parsed args
+ """
if not Path(args.tflite_path).is_file():
- raise Exception(f"{args.tflite_path} not found")
+ raise ValueError(f"{args.tflite_path} not found")
# Cpp filename:
cpp_filename = (Path(args.output_dir) / (Path(args.tflite_path).name + ".cc")).resolve()
@@ -80,19 +122,16 @@ def main(args):
cpp_filename.parent.mkdir(exist_ok=True)
- header_template = env.get_template(args.license_template)
-
- hdr = header_template.render(script_name=Path(__file__).name,
- file_name=Path(args.tflite_path).name,
- gen_time=datetime.datetime.now(),
- year=datetime.datetime.now().year)
+ hdr = GenUtils.gen_header(env, args.license_template, Path(args.tflite_path).name)
- env.get_template('tflite.cc.template').stream(common_template_header=hdr,
- model_data=get_tflite_data(args.tflite_path),
- expressions=args.expr,
- additional_headers=args.headers,
- namespaces=args.namespaces).dump(str(cpp_filename))
+ env \
+ .get_template('tflite.cc.template') \
+ .stream(common_template_header=hdr,
+ model_data=get_tflite_data(args.tflite_path),
+ expressions=args.expr,
+ additional_headers=args.headers,
+ namespaces=args.namespaces).dump(str(cpp_filename))
if __name__ == '__main__':
- main(args)
+ main(parsed_args)