diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-07-21 14:06:50 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-08-19 10:23:23 +0100 |
commit | 664d8c55609253e68d153a91514c8fefa00557b1 (patch) | |
tree | 4b2a0ecaf30e9151d6b971a24fa6c6104884896f /src/mlia/devices/tosa/operators.py | |
parent | a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (diff) | |
download | mlia-664d8c55609253e68d153a91514c8fefa00557b1.tar.gz |
MLIA-549 Integrate TOSA checker into MLIA
- Add new module for TOSA
- Add advisor workflow components
- Use TOSA checker for getting operators compatibility
information
Change-Id: I769e5e2a84e15779658f0895b4a347384def63bf
Diffstat (limited to 'src/mlia/devices/tosa/operators.py')
-rw-r--r-- | src/mlia/devices/tosa/operators.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py new file mode 100644 index 0000000..4f3df10 --- /dev/null +++ b/src/mlia/devices/tosa/operators.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Operators module.""" +from dataclasses import dataclass +from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Protocol + +from mlia.core._typing import PathOrFileLike + + +class TOSAChecker(Protocol): + """TOSA checker protocol.""" + + def is_tosa_compatible(self) -> bool: + """Return true if model is TOSA compatible.""" + + def _get_tosa_compatibility_for_ops(self) -> List[Any]: + """Return list of operators.""" + + +@dataclass +class Operator: + """Operator's TOSA compatibility info.""" + + location: str + name: str + is_tosa_compatible: bool + + +@dataclass +class TOSACompatibilityInfo: + """Models' TOSA compatibility information.""" + + tosa_compatible: bool + operators: List[Operator] + + +def get_tosa_compatibility_info( + tflite_model_path: PathOrFileLike, +) -> TOSACompatibilityInfo: + """Return list of the operators.""" + checker = get_tosa_checker(tflite_model_path) + + if checker is None: + raise Exception( + "TOSA checker is not available. " + "Please make sure that 'tosa_checker' package is installed: " + "pip install mlia[tosa]" + ) + + ops = [ + Operator(item.location, item.name, item.is_tosa_compatible) + for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access + ] + + return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) + + +def get_tosa_checker(tflite_model_path: PathOrFileLike) -> Optional[TOSAChecker]: + """Return instance of the TOSA checker.""" + try: + import tosa_checker as tc # pylint: disable=import-outside-toplevel + except ImportError: + return None + + checker = tc.TOSAChecker(str(tflite_model_path)) + return cast(TOSAChecker, checker) |