blob: 7c700b123b86a0f0115180a93d12a7f299deefb3 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
|
# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Check if a TFLite model file is Vela-optimised."""
import struct
from pathlib import Path
from ethosu.vela.tflite.Model import Model
from aiet.cli.common import InvalidTFLiteFileError
from aiet.cli.common import ModelOptimisedException
from aiet.utils.fs import read_file_as_bytearray
def get_model_from_file(input_model_file: Path) -> Model:
    """Load *input_model_file* and wrap its bytes in a flatc-generated Model.

    Args:
        input_model_file: path of the TFLite file to read.

    Returns:
        The root ``Model`` object produced by ``Model.GetRootAsModel``.

    Raises:
        InvalidTFLiteFileError: if ``GetRootAsModel`` fails with TypeError,
            RuntimeError or struct.error. The original error is chained.
    """
    raw_bytes = read_file_as_bytearray(input_model_file)
    try:
        return Model.GetRootAsModel(raw_bytes, 0)
    except (TypeError, RuntimeError, struct.error) as parse_error:
        message = f"Error reading in model from {input_model_file}."
        raise InvalidTFLiteFileError(message) from parse_error
def is_vela_optimised(tflite_model: Model) -> bool:
    """Tell whether *tflite_model* contains an 'ethos-u' custom operator.

    The check proceeds in three steps:
    1. Collect the model's operator codes.
    2. Keep the non-None CustomCode() values.
    3. Look for the 'ethos-u' name among them.

    Returns:
        True if an 'ethos-u' custom code is present, otherwise False.
    """
    return check_custom_codes_for_ethosu(
        get_custom_codes_from_operators(get_operators_from_model(tflite_model))
    )
def get_operators_from_model(tflite_model: Model) -> list:
    """Return every entry of the model's OperatorCodes table, in index order.

    Args:
        tflite_model: object exposing ``OperatorCodesLength()`` and
            ``OperatorCodes(index)``, such as the flatc-generated Model.

    Returns:
        A list with one ``OperatorCodes(i)`` result per index
        ``0 .. OperatorCodesLength() - 1``.
    """
    code_count = tflite_model.OperatorCodesLength()
    return list(map(tflite_model.OperatorCodes, range(code_count)))
def get_custom_codes_from_operators(operators: list) -> list:
    """Return the CustomCode() value of each operator that has one.

    Operators whose ``CustomCode()`` returns None are skipped. The order of
    the remaining values matches the order of *operators*.

    Args:
        operators: iterable of objects exposing a ``CustomCode()`` method,
            e.g. the list returned by ``get_operators_from_model``.

    Returns:
        A list of the non-None ``CustomCode()`` results.
    """
    # Evaluate CustomCode() exactly once per operator, instead of once for
    # the None filter and a second time to collect the value.
    custom_codes = (operator.CustomCode() for operator in operators)
    return [code for code in custom_codes if code is not None]
def check_custom_codes_for_ethosu(custom_codes: list) -> bool:
    """Check for existence of the 'ethos-u' name in the custom codes.

    Args:
        custom_codes: iterable of custom-code names. Entries are normally
            bytes (as collected by ``get_custom_codes_from_operators``);
            ``str`` entries are also accepted.

    Returns:
        True if any entry equals 'ethos-u', otherwise False.
    """
    # Compare raw bytes rather than decoding them. For valid UTF-8 this gives
    # the same answer as decode("utf-8") == "ethos-u". A malformed,
    # non-UTF-8 custom code then just fails to match, instead of raising
    # UnicodeDecodeError and aborting the check.
    return any(
        custom_code == ("ethos-u" if isinstance(custom_code, str) else b"ethos-u")
        for custom_code in custom_codes
    )
def check_model(tflite_file_name: str) -> None:
    """Raise an exception if the model in the given file is Vela optimised.

    Args:
        tflite_file_name: path of the TFLite file to inspect.

    Raises:
        InvalidTFLiteFileError: if the file cannot be read as a TFLite model.
            This covers failures while reading the root and failures while
            walking the operator codes.
        ModelOptimisedException: if an 'ethos-u' custom operator is found.

    If the model is not Vela optimised, a message saying so is printed.
    """
    tflite_path = Path(tflite_file_name)
    tflite_model = get_model_from_file(tflite_path)
    # get_model_from_file() only guards GetRootAsModel(). The flatc-generated
    # accessors decode the buffer on access, so a corrupt file can fail
    # later, while its operator codes are read. Report that failure with the
    # same exception type rather than letting a raw struct.error escape.
    try:
        optimised = is_vela_optimised(tflite_model)
    except (
        TypeError,
        RuntimeError,
        struct.error,
        UnicodeDecodeError,
    ) as tflite_error:
        raise InvalidTFLiteFileError(
            f"Error reading in model from {tflite_path}."
        ) from tflite_error
    if optimised:
        raise ModelOptimisedException(
            f"TFLite model in {tflite_file_name} is already "
            f"vela optimised ('ethos-u' custom op detected)."
        )
    print(
        f"TFLite model in {tflite_file_name} is not vela optimised "
        f"('ethos-u' custom op not detected)."
    )
|