From a34163c9d9a5cc0416bcaea2ebf8383bda9d505c Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 24 Nov 2022 08:34:38 +0000 Subject: Move TOSA checker functions into separate module - Create module "compat" for tosa_checker backend - Move TOSA checker functions into new module - Update tests Change-Id: Ia07034515fe43b2061b8892535067d21315cc721 --- src/mlia/backend/tosa_checker/compat.py | 69 +++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 src/mlia/backend/tosa_checker/compat.py (limited to 'src/mlia/backend/tosa_checker') diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py new file mode 100644 index 0000000..e1bcb24 --- /dev/null +++ b/src/mlia/backend/tosa_checker/compat.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA compatibility module.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from typing import cast +from typing import Protocol + +from mlia.core.typing import PathOrFileLike + + +class TOSAChecker(Protocol): + """TOSA checker protocol.""" + + def is_tosa_compatible(self) -> bool: + """Return true if model is TOSA compatible.""" + + def _get_tosa_compatibility_for_ops(self) -> list[Any]: + """Return list of operators.""" + + +@dataclass +class Operator: + """Operator's TOSA compatibility info.""" + + location: str + name: str + is_tosa_compatible: bool + + +@dataclass +class TOSACompatibilityInfo: + """Models' TOSA compatibility information.""" + + tosa_compatible: bool + operators: list[Operator] + + +def get_tosa_compatibility_info( + tflite_model_path: PathOrFileLike, +) -> TOSACompatibilityInfo: + """Return list of the operators.""" + checker = get_tosa_checker(tflite_model_path) + + if checker is None: + raise Exception( + "TOSA checker is not available. " + "Please make sure that 'tosa-checker' backend is installed." + ) + + ops = [ + Operator(item.location, item.name, item.is_tosa_compatible) + for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access + ] + + return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) + + +def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: + """Return instance of the TOSA checker.""" + try: + import tosa_checker as tc # pylint: disable=import-outside-toplevel + except ImportError: + return None + + checker = tc.TOSAChecker(str(tflite_model_path)) + return cast(TOSAChecker, checker) -- cgit v1.2.1