diff options
Diffstat (limited to 'src/aiet/resources/tools/vela/check_model.py')
-rw-r--r-- | src/aiet/resources/tools/vela/check_model.py | 75 |
1 files changed, 0 insertions, 75 deletions
diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py deleted file mode 100644 index 7c700b1..0000000 --- a/src/aiet/resources/tools/vela/check_model.py +++ /dev/null @@ -1,75 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Check if a TFLite model file is Vela-optimised.""" -import struct -from pathlib import Path - -from ethosu.vela.tflite.Model import Model - -from aiet.cli.common import InvalidTFLiteFileError -from aiet.cli.common import ModelOptimisedException -from aiet.utils.fs import read_file_as_bytearray - - -def get_model_from_file(input_model_file: Path) -> Model: - """Generate Model instance from TFLite file using flatc generated code.""" - buffer = read_file_as_bytearray(input_model_file) - try: - model = Model.GetRootAsModel(buffer, 0) - except (TypeError, RuntimeError, struct.error) as tflite_error: - raise InvalidTFLiteFileError( - f"Error reading in model from {input_model_file}." - ) from tflite_error - return model - - -def is_vela_optimised(tflite_model: Model) -> bool: - """Return True if 'ethos-u' custom operator found in the Model.""" - operators = get_operators_from_model(tflite_model) - - custom_codes = get_custom_codes_from_operators(operators) - - return check_custom_codes_for_ethosu(custom_codes) - - -def get_operators_from_model(tflite_model: Model) -> list: - """Return list of the unique operator codes used in the Model.""" - return [ - tflite_model.OperatorCodes(index) - for index in range(tflite_model.OperatorCodesLength()) - ] - - -def get_custom_codes_from_operators(operators: list) -> list: - """Return list of each operator's CustomCode() strings, if they exist.""" - return [ - operator.CustomCode() - for operator in operators - if operator.CustomCode() is not None - ] - - -def check_custom_codes_for_ethosu(custom_codes: list) -> bool: - """Check for existence of ethos-u string in the custom codes.""" - return any( - custom_code_name.decode("utf-8") == "ethos-u" - for custom_code_name in custom_codes - ) - - -def check_model(tflite_file_name: str) -> None: - """Raise an exception if model in given file is Vela optimised.""" - tflite_path = Path(tflite_file_name) - - tflite_model = get_model_from_file(tflite_path) - - if is_vela_optimised(tflite_model): - raise ModelOptimisedException( - f"TFLite model in {tflite_file_name} is already " - f"vela optimised ('ethos-u' custom op detected)." - ) - - print( - f"TFLite model in {tflite_file_name} is not vela optimised " - f"('ethos-u' custom op not detected)." - ) |