aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/resources/tools/vela/check_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/aiet/resources/tools/vela/check_model.py')
-rw-r--r--src/aiet/resources/tools/vela/check_model.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py
new file mode 100644
index 0000000..7c700b1
--- /dev/null
+++ b/src/aiet/resources/tools/vela/check_model.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Check if a TFLite model file is Vela-optimised."""
+import struct
+from pathlib import Path
+
+from ethosu.vela.tflite.Model import Model
+
+from aiet.cli.common import InvalidTFLiteFileError
+from aiet.cli.common import ModelOptimisedException
+from aiet.utils.fs import read_file_as_bytearray
+
+
+def get_model_from_file(input_model_file: Path) -> Model:
+ """Generate Model instance from TFLite file using flatc generated code."""
+ buffer = read_file_as_bytearray(input_model_file)
+ try:
+ model = Model.GetRootAsModel(buffer, 0)
+ except (TypeError, RuntimeError, struct.error) as tflite_error:
+ raise InvalidTFLiteFileError(
+ f"Error reading in model from {input_model_file}."
+ ) from tflite_error
+ return model
+
+
+def is_vela_optimised(tflite_model: Model) -> bool:
+ """Return True if 'ethos-u' custom operator found in the Model."""
+ operators = get_operators_from_model(tflite_model)
+
+ custom_codes = get_custom_codes_from_operators(operators)
+
+ return check_custom_codes_for_ethosu(custom_codes)
+
+
+def get_operators_from_model(tflite_model: Model) -> list:
+ """Return list of the unique operator codes used in the Model."""
+ return [
+ tflite_model.OperatorCodes(index)
+ for index in range(tflite_model.OperatorCodesLength())
+ ]
+
+
+def get_custom_codes_from_operators(operators: list) -> list:
+ """Return list of each operator's CustomCode() strings, if they exist."""
+ return [
+ operator.CustomCode()
+ for operator in operators
+ if operator.CustomCode() is not None
+ ]
+
+
+def check_custom_codes_for_ethosu(custom_codes: list) -> bool:
+ """Check for existence of ethos-u string in the custom codes."""
+ return any(
+ custom_code_name.decode("utf-8") == "ethos-u"
+ for custom_code_name in custom_codes
+ )
+
+
+def check_model(tflite_file_name: str) -> None:
+ """Raise an exception if model in given file is Vela optimised."""
+ tflite_path = Path(tflite_file_name)
+
+ tflite_model = get_model_from_file(tflite_path)
+
+ if is_vela_optimised(tflite_model):
+ raise ModelOptimisedException(
+ f"TFLite model in {tflite_file_name} is already "
+ f"vela optimised ('ethos-u' custom op detected)."
+ )
+
+ print(
+ f"TFLite model in {tflite_file_name} is not vela optimised "
+ f"('ethos-u' custom op not detected)."
+ )