From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Mon, 30 May 2022 13:34:14 +0100 Subject: Add MLIA codebase Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd --- src/aiet/resources/tools/vela/check_model.py | 75 ++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 src/aiet/resources/tools/vela/check_model.py (limited to 'src/aiet/resources/tools/vela/check_model.py') diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py new file mode 100644 index 0000000..7c700b1 --- /dev/null +++ b/src/aiet/resources/tools/vela/check_model.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Check if a TFLite model file is Vela-optimised.""" +import struct +from pathlib import Path + +from ethosu.vela.tflite.Model import Model + +from aiet.cli.common import InvalidTFLiteFileError +from aiet.cli.common import ModelOptimisedException +from aiet.utils.fs import read_file_as_bytearray + + +def get_model_from_file(input_model_file: Path) -> Model: + """Generate Model instance from TFLite file using flatc generated code.""" + buffer = read_file_as_bytearray(input_model_file) + try: + model = Model.GetRootAsModel(buffer, 0) + except (TypeError, RuntimeError, struct.error) as tflite_error: + raise InvalidTFLiteFileError( + f"Error reading in model from {input_model_file}." + ) from tflite_error + return model + + +def is_vela_optimised(tflite_model: Model) -> bool: + """Return True if 'ethos-u' custom operator found in the Model.""" + operators = get_operators_from_model(tflite_model) + + custom_codes = get_custom_codes_from_operators(operators) + + return check_custom_codes_for_ethosu(custom_codes) + + +def get_operators_from_model(tflite_model: Model) -> list: + """Return list of the unique operator codes used in the Model.""" + return [ + tflite_model.OperatorCodes(index) + for index in range(tflite_model.OperatorCodesLength()) + ] + + +def get_custom_codes_from_operators(operators: list) -> list: + """Return list of each operator's CustomCode() strings, if they exist.""" + return [ + operator.CustomCode() + for operator in operators + if operator.CustomCode() is not None + ] + + +def check_custom_codes_for_ethosu(custom_codes: list) -> bool: + """Check for existence of ethos-u string in the custom codes.""" + return any( + custom_code_name.decode("utf-8") == "ethos-u" + for custom_code_name in custom_codes + ) + + +def check_model(tflite_file_name: str) -> None: + """Raise an exception if model in given file is Vela optimised.""" + tflite_path = Path(tflite_file_name) + + tflite_model = get_model_from_file(tflite_path) + + if is_vela_optimised(tflite_model): + raise ModelOptimisedException( + f"TFLite model in {tflite_file_name} is already " + f"vela optimised ('ethos-u' custom op detected)." + ) + + print( + f"TFLite model in {tflite_file_name} is not vela optimised " + f"('ethos-u' custom op not detected)." + ) -- cgit v1.2.1