aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/backend/tosa_checker/compat.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/backend/tosa_checker/compat.py')
-rw-r--r--src/mlia/backend/tosa_checker/compat.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py
new file mode 100644
index 0000000..e1bcb24
--- /dev/null
+++ b/src/mlia/backend/tosa_checker/compat.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA compatibility module."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+from typing import cast
+from typing import Protocol
+
+from mlia.core.typing import PathOrFileLike
+
+
+class TOSAChecker(Protocol):
+ """TOSA checker protocol."""
+
+ def is_tosa_compatible(self) -> bool:
+ """Return true if model is TOSA compatible."""
+
+ def _get_tosa_compatibility_for_ops(self) -> list[Any]:
+ """Return list of operators."""
+
+
+@dataclass
+class Operator:
+ """Operator's TOSA compatibility info."""
+
+ location: str
+ name: str
+ is_tosa_compatible: bool
+
+
+@dataclass
+class TOSACompatibilityInfo:
+ """Models' TOSA compatibility information."""
+
+ tosa_compatible: bool
+ operators: list[Operator]
+
+
+def get_tosa_compatibility_info(
+ tflite_model_path: PathOrFileLike,
+) -> TOSACompatibilityInfo:
+ """Return list of the operators."""
+ checker = get_tosa_checker(tflite_model_path)
+
+ if checker is None:
+ raise Exception(
+ "TOSA checker is not available. "
+ "Please make sure that 'tosa-checker' backend is installed."
+ )
+
+ ops = [
+ Operator(item.location, item.name, item.is_tosa_compatible)
+ for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access
+ ]
+
+ return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops)
+
+
+def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None:
+ """Return instance of the TOSA checker."""
+ try:
+ import tosa_checker as tc # pylint: disable=import-outside-toplevel
+ except ImportError:
+ return None
+
+ checker = tc.TOSAChecker(str(tflite_model_path))
+ return cast(TOSAChecker, checker)