summaryrefslogtreecommitdiff
path: root/scripts/py/gen_model_cpp.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/py/gen_model_cpp.py')
-rw-r--r--scripts/py/gen_model_cpp.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py
index c43c93a..de71992 100644
--- a/scripts/py/gen_model_cpp.py
+++ b/scripts/py/gen_model_cpp.py
@@ -19,9 +19,9 @@ project directly. This should be called as part of cmake framework
should the models need to be generated at configuration stage.
"""
import datetime
-import os
from argparse import ArgumentParser
from pathlib import Path
+
from jinja2 import Environment, FileSystemLoader
import binascii
@@ -36,7 +36,7 @@ parser.add_argument("--license_template", type=str, help="Header template file",
default="header_template.txt")
args = parser.parse_args()
-env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), 'templates')),
+env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
trim_blocks=True,
lstrip_blocks=True)
@@ -70,20 +70,20 @@ def get_tflite_data(tflite_path: str) -> list:
def main(args):
- if not os.path.isfile(args.tflite_path):
+ if not Path(args.tflite_path).is_file():
raise Exception(f"{args.tflite_path} not found")
# Cpp filename:
- cpp_filename = Path(os.path.join(args.output_dir, os.path.basename(args.tflite_path) + ".cc")).absolute()
- print(f"++ Converting {os.path.basename(args.tflite_path)} to\
- {os.path.basename(cpp_filename)}")
+ cpp_filename = (Path(args.output_dir) / (Path(args.tflite_path).name + ".cc")).resolve()
+ print(f"++ Converting {Path(args.tflite_path).name} to\
+ {cpp_filename.name}")
- os.makedirs(cpp_filename.parent, exist_ok=True)
+ cpp_filename.parent.mkdir(exist_ok=True)
header_template = env.get_template(args.license_template)
- hdr = header_template.render(script_name=os.path.basename(__file__),
- file_name=os.path.basename(args.tflite_path),
+ hdr = header_template.render(script_name=Path(__file__).name,
+ file_name=Path(args.tflite_path).name,
gen_time=datetime.datetime.now(),
year=datetime.datetime.now().year)