summaryrefslogtreecommitdiff
path: root/scripts/py
diff options
context:
space:
mode:
authorKshitij Sisodia <kshitij.sisodia@arm.com>2021-06-25 09:55:14 +0100
committerIsabella Gottardi <isabella.gottardi@arm.com>2021-07-01 18:15:26 +0000
commit1da52aeb4468ad97c09e383400bbabc8c3c77227 (patch)
treed343f44adeee5beea565e1dedfd5be74bf35b039 /scripts/py
parentd475f09826e838c2563ea634dafcbdff901c61b8 (diff)
downloadml-embedded-evaluation-kit-1da52aeb4468ad97c09e383400bbabc8c3c77227.tar.gz
MLECO-718: Minor imporovements to source gen
Changes: * minor speed up for tflite to C++ file generation * virtual env's pip upgrade before installation of packages Change-Id: If8cef85779b7381f444f608b565da0b8f994d364
Diffstat (limited to 'scripts/py')
-rw-r--r--scripts/py/gen_model_cpp.py49
1 files changed, 25 insertions, 24 deletions
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py
index 4843668..c43c93a 100644
--- a/scripts/py/gen_model_cpp.py
+++ b/scripts/py/gen_model_cpp.py
@@ -23,6 +23,7 @@ import os
from argparse import ArgumentParser
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
+import binascii
parser = ArgumentParser()
@@ -40,32 +41,32 @@ env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__)
lstrip_blocks=True)
-def write_tflite_data(tflite_path):
- # Extract array elements
+def get_tflite_data(tflite_path: str) -> list:
+ """
+ Reads a binary file and returns a C style array as a
+ list of strings.
- bytes = model_hex_bytes(tflite_path)
- line = '{\n'
- i = 1
- while True:
- try:
- el = next(bytes)
- line = line + el + ', '
- if i % 20 == 0:
- yield line
- line = ''
- i += 1
- except StopIteration:
- line = line[:-2] + '};\n'
- yield line
- break
+ Argument:
+ tflite_path: path to the tflite model.
-
-def model_hex_bytes(tflite_path):
+ Returns:
+ list of strings
+ """
with open(tflite_path, 'rb') as tflite_model:
- byte = tflite_model.read(1)
- while byte != b"":
- yield f'0x{byte.hex()}'
- byte = tflite_model.read(1)
+ data = tflite_model.read()
+
+ bytes_per_line = 32
+ hex_digits_per_line = bytes_per_line * 2
+ hexstream = binascii.hexlify(data).decode('utf-8')
+ hexstring = '{'
+
+ for i in range(0, len(hexstream), 2):
+ if 0 == (i % hex_digits_per_line):
+ hexstring += "\n"
+ hexstring += '0x' + hexstream[i:i+2] + ", "
+
+ hexstring += '};\n'
+ return [hexstring]
def main(args):
@@ -87,7 +88,7 @@ def main(args):
year=datetime.datetime.now().year)
env.get_template('tflite.cc.template').stream(common_template_header=hdr,
- model_data=write_tflite_data(args.tflite_path),
+ model_data=get_tflite_data(args.tflite_path),
expressions=args.expr,
additional_headers=args.headers,
namespaces=args.namespaces).dump(str(cpp_filename))