diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-11-24 08:34:38 +0000 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-11-29 14:44:13 +0000 |
commit | a34163c9d9a5cc0416bcaea2ebf8383bda9d505c (patch) | |
tree | 304c01c607b3a93c250a38df53c417f62196b5fa /src/mlia/backend/tosa_checker/compat.py | |
parent | 37959522a805a5e23c930ed79aac84920c3cb208 (diff) | |
download | mlia-a34163c9d9a5cc0416bcaea2ebf8383bda9d505c.tar.gz |
Move TOSA checker functions into separate module
- Create module "compat" for tosa_checker backend
- Move TOSA checker functions into new module
- Update tests
Change-Id: Ia07034515fe43b2061b8892535067d21315cc721
Diffstat (limited to 'src/mlia/backend/tosa_checker/compat.py')
-rw-r--r-- | src/mlia/backend/tosa_checker/compat.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py new file mode 100644 index 0000000..e1bcb24 --- /dev/null +++ b/src/mlia/backend/tosa_checker/compat.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA compatibility module.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from typing import cast +from typing import Protocol + +from mlia.core.typing import PathOrFileLike + + +class TOSAChecker(Protocol): + """TOSA checker protocol.""" + + def is_tosa_compatible(self) -> bool: + """Return true if model is TOSA compatible.""" + + def _get_tosa_compatibility_for_ops(self) -> list[Any]: + """Return list of operators.""" + + +@dataclass +class Operator: + """Operator's TOSA compatibility info.""" + + location: str + name: str + is_tosa_compatible: bool + + +@dataclass +class TOSACompatibilityInfo: + """Models' TOSA compatibility information.""" + + tosa_compatible: bool + operators: list[Operator] + + +def get_tosa_compatibility_info( + tflite_model_path: PathOrFileLike, +) -> TOSACompatibilityInfo: + """Return list of the operators.""" + checker = get_tosa_checker(tflite_model_path) + + if checker is None: + raise Exception( + "TOSA checker is not available. " + "Please make sure that 'tosa-checker' backend is installed." + ) + + ops = [ + Operator(item.location, item.name, item.is_tosa_compatible) + for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access + ] + + return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) + + +def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: + """Return instance of the TOSA checker.""" + try: + import tosa_checker as tc # pylint: disable=import-outside-toplevel + except ImportError: + return None + + checker = tc.TOSAChecker(str(tflite_model_path)) + return cast(TOSAChecker, checker) |