summaryrefslogtreecommitdiff
path: root/scripts/py/gen_model_cpp.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/py/gen_model_cpp.py')
-rw-r--r--scripts/py/gen_model_cpp.py49
1 files changed, 25 insertions, 24 deletions
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py
index 4843668..c43c93a 100644
--- a/scripts/py/gen_model_cpp.py
+++ b/scripts/py/gen_model_cpp.py
@@ -23,6 +23,7 @@ import os
from argparse import ArgumentParser
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
+import binascii
parser = ArgumentParser()
@@ -40,32 +41,32 @@ env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__)
lstrip_blocks=True)
-def write_tflite_data(tflite_path):
- # Extract array elements
+def get_tflite_data(tflite_path: str) -> list:
+ """
+ Reads a binary file and returns a C style array as a
+ list of strings.
- bytes = model_hex_bytes(tflite_path)
- line = '{\n'
- i = 1
- while True:
- try:
- el = next(bytes)
- line = line + el + ', '
- if i % 20 == 0:
- yield line
- line = ''
- i += 1
- except StopIteration:
- line = line[:-2] + '};\n'
- yield line
- break
+ Argument:
+ tflite_path: path to the tflite model.
-
-def model_hex_bytes(tflite_path):
+ Returns:
+ list of strings
+ """
with open(tflite_path, 'rb') as tflite_model:
- byte = tflite_model.read(1)
- while byte != b"":
- yield f'0x{byte.hex()}'
- byte = tflite_model.read(1)
+ data = tflite_model.read()
+
+ bytes_per_line = 32
+ hex_digits_per_line = bytes_per_line * 2
+ hexstream = binascii.hexlify(data).decode('utf-8')
+ hexstring = '{'
+
+ for i in range(0, len(hexstream), 2):
+ if 0 == (i % hex_digits_per_line):
+ hexstring += "\n"
+ hexstring += '0x' + hexstream[i:i+2] + ", "
+
+ hexstring += '};\n'
+ return [hexstring]
def main(args):
@@ -87,7 +88,7 @@ def main(args):
year=datetime.datetime.now().year)
env.get_template('tflite.cc.template').stream(common_template_header=hdr,
- model_data=write_tflite_data(args.tflite_path),
+ model_data=get_tflite_data(args.tflite_path),
expressions=args.expr,
additional_headers=args.headers,
namespaces=args.namespaces).dump(str(cpp_filename))