summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKshitij Sisodia <kshitij.sisodia@arm.com>2021-06-25 09:55:14 +0100
committerIsabella Gottardi <isabella.gottardi@arm.com>2021-07-01 18:15:26 +0000
commit1da52aeb4468ad97c09e383400bbabc8c3c77227 (patch)
treed343f44adeee5beea565e1dedfd5be74bf35b039
parentd475f09826e838c2563ea634dafcbdff901c61b8 (diff)
downloadml-embedded-evaluation-kit-1da52aeb4468ad97c09e383400bbabc8c3c77227.tar.gz
MLECO-718: Minor imporovements to source gen
Changes: * minor speed up for tflite to C++ file generation * virtual env's pip upgrade before installation of packages Change-Id: If8cef85779b7381f444f608b565da0b8f994d364
-rw-r--r--scripts/cmake/source_gen_utils.cmake22
-rw-r--r--scripts/py/gen_model_cpp.py49
2 files changed, 44 insertions, 27 deletions
diff --git a/scripts/cmake/source_gen_utils.cmake b/scripts/cmake/source_gen_utils.cmake
index 92ec53d..9d27b4d 100644
--- a/scripts/cmake/source_gen_utils.cmake
+++ b/scripts/cmake/source_gen_utils.cmake
@@ -272,22 +272,38 @@ function(setup_source_generator)
message(STATUS "Using existing python at ${PYTHON}")
return()
endif ()
+
message(STATUS "Configuring python environment at ${PYTHON}")
+
execute_process(
COMMAND ${PY_EXEC} -m venv ${CMAKE_BINARY_DIR}/pyenv
RESULT_VARIABLE return_code
)
+ if (NOT return_code STREQUAL "0")
+ message(FATAL_ERROR "Failed to setup python3 environment. Return code: ${return_code}")
+ endif ()
+
+ execute_process(
+ COMMAND ${PYTHON} -m pip install --upgrade pip
+ RESULT_VARIABLE return_code
+ )
if (NOT return_code EQUAL "0")
- message(FATAL_ERROR "Failed to setup python3 environment")
+ message(FATAL_ERROR "Failed to upgrade pip")
endif ()
- execute_process(COMMAND ${PYTHON} -m pip install wheel)
+ execute_process(
+ COMMAND ${PYTHON} -m pip install wheel
+ RESULT_VARIABLE return_code
+ )
+ if (NOT return_code EQUAL "0")
+ message(FATAL_ERROR "Failed to install wheel")
+ endif ()
execute_process(
COMMAND ${PYTHON} -m pip install -r ${SCRIPTS_DIR}/py/requirements.txt
RESULT_VARIABLE return_code
)
if (NOT return_code EQUAL "0")
- message(FATAL_ERROR "Failed to setup python3 environment")
+ message(FATAL_ERROR "Failed to install requirements")
endif ()
endfunction()
diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py
index 4843668..c43c93a 100644
--- a/scripts/py/gen_model_cpp.py
+++ b/scripts/py/gen_model_cpp.py
@@ -23,6 +23,7 @@ import os
from argparse import ArgumentParser
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
+import binascii
parser = ArgumentParser()
@@ -40,32 +41,32 @@ env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__)
lstrip_blocks=True)
-def write_tflite_data(tflite_path):
- # Extract array elements
+def get_tflite_data(tflite_path: str) -> list:
+ """
+ Reads a binary file and returns a C style array as a
+ list of strings.
- bytes = model_hex_bytes(tflite_path)
- line = '{\n'
- i = 1
- while True:
- try:
- el = next(bytes)
- line = line + el + ', '
- if i % 20 == 0:
- yield line
- line = ''
- i += 1
- except StopIteration:
- line = line[:-2] + '};\n'
- yield line
- break
+ Argument:
+ tflite_path: path to the tflite model.
-
-def model_hex_bytes(tflite_path):
+ Returns:
+ list of strings
+ """
with open(tflite_path, 'rb') as tflite_model:
- byte = tflite_model.read(1)
- while byte != b"":
- yield f'0x{byte.hex()}'
- byte = tflite_model.read(1)
+ data = tflite_model.read()
+
+ bytes_per_line = 32
+ hex_digits_per_line = bytes_per_line * 2
+ hexstream = binascii.hexlify(data).decode('utf-8')
+ hexstring = '{'
+
+ for i in range(0, len(hexstream), 2):
+ if 0 == (i % hex_digits_per_line):
+ hexstring += "\n"
+ hexstring += '0x' + hexstream[i:i+2] + ", "
+
+ hexstring += '};\n'
+ return [hexstring]
def main(args):
@@ -87,7 +88,7 @@ def main(args):
year=datetime.datetime.now().year)
env.get_template('tflite.cc.template').stream(common_template_header=hdr,
- model_data=write_tflite_data(args.tflite_path),
+ model_data=get_tflite_data(args.tflite_path),
expressions=args.expr,
additional_headers=args.headers,
namespaces=args.namespaces).dump(str(cpp_filename))