From 1da52aeb4468ad97c09e383400bbabc8c3c77227 Mon Sep 17 00:00:00 2001 From: Kshitij Sisodia Date: Fri, 25 Jun 2021 09:55:14 +0100 Subject: MLECO-718: Minor imporovements to source gen Changes: * minor speed up for tflite to C++ file generation * virtual env's pip upgrade before installation of packages Change-Id: If8cef85779b7381f444f608b565da0b8f994d364 --- scripts/cmake/source_gen_utils.cmake | 22 +++++++++++++--- scripts/py/gen_model_cpp.py | 49 ++++++++++++++++++------------------ 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/scripts/cmake/source_gen_utils.cmake b/scripts/cmake/source_gen_utils.cmake index 92ec53d..9d27b4d 100644 --- a/scripts/cmake/source_gen_utils.cmake +++ b/scripts/cmake/source_gen_utils.cmake @@ -272,22 +272,38 @@ function(setup_source_generator) message(STATUS "Using existing python at ${PYTHON}") return() endif () + message(STATUS "Configuring python environment at ${PYTHON}") + execute_process( COMMAND ${PY_EXEC} -m venv ${CMAKE_BINARY_DIR}/pyenv RESULT_VARIABLE return_code ) + if (NOT return_code STREQUAL "0") + message(FATAL_ERROR "Failed to setup python3 environment. Return code: ${return_code}") + endif () + + execute_process( + COMMAND ${PYTHON} -m pip install --upgrade pip + RESULT_VARIABLE return_code + ) if (NOT return_code EQUAL "0") - message(FATAL_ERROR "Failed to setup python3 environment") + message(FATAL_ERROR "Failed to upgrade pip") endif () - execute_process(COMMAND ${PYTHON} -m pip install wheel) + execute_process( + COMMAND ${PYTHON} -m pip install wheel + RESULT_VARIABLE return_code + ) + if (NOT return_code EQUAL "0") + message(FATAL_ERROR "Failed to install wheel") + endif () execute_process( COMMAND ${PYTHON} -m pip install -r ${SCRIPTS_DIR}/py/requirements.txt RESULT_VARIABLE return_code ) if (NOT return_code EQUAL "0") - message(FATAL_ERROR "Failed to setup python3 environment") + message(FATAL_ERROR "Failed to install requirements") endif () endfunction() diff --git a/scripts/py/gen_model_cpp.py b/scripts/py/gen_model_cpp.py index 4843668..c43c93a 100644 --- a/scripts/py/gen_model_cpp.py +++ b/scripts/py/gen_model_cpp.py @@ -23,6 +23,7 @@ import os from argparse import ArgumentParser from pathlib import Path from jinja2 import Environment, FileSystemLoader +import binascii parser = ArgumentParser() @@ -40,32 +41,32 @@ env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__) lstrip_blocks=True) -def write_tflite_data(tflite_path): - # Extract array elements +def get_tflite_data(tflite_path: str) -> list: + """ + Reads a binary file and returns a C style array as a + list of strings. - bytes = model_hex_bytes(tflite_path) - line = '{\n' - i = 1 - while True: - try: - el = next(bytes) - line = line + el + ', ' - if i % 20 == 0: - yield line - line = '' - i += 1 - except StopIteration: - line = line[:-2] + '};\n' - yield line - break + Argument: + tflite_path: path to the tflite model. - -def model_hex_bytes(tflite_path): + Returns: + list of strings + """ with open(tflite_path, 'rb') as tflite_model: - byte = tflite_model.read(1) - while byte != b"": - yield f'0x{byte.hex()}' - byte = tflite_model.read(1) + data = tflite_model.read() + + bytes_per_line = 32 + hex_digits_per_line = bytes_per_line * 2 + hexstream = binascii.hexlify(data).decode('utf-8') + hexstring = '{' + + for i in range(0, len(hexstream), 2): + if 0 == (i % hex_digits_per_line): + hexstring += "\n" + hexstring += '0x' + hexstream[i:i+2] + ", " + + hexstring += '};\n' + return [hexstring] def main(args): @@ -87,7 +88,7 @@ def main(args): year=datetime.datetime.now().year) env.get_template('tflite.cc.template').stream(common_template_header=hdr, - model_data=write_tflite_data(args.tflite_path), + model_data=get_tflite_data(args.tflite_path), expressions=args.expr, additional_headers=args.headers, namespaces=args.namespaces).dump(str(cpp_filename)) -- cgit v1.2.1