aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/backend/tosa_checker/compat.py
blob: bd217747e511f1af50641a473af07457b95f166c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA compatibility module."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any
from typing import cast
from typing import Protocol

from mlia.backend.errors import BackendUnavailableError
from mlia.core.typing import PathOrFileLike


class TOSAChecker(Protocol):
    """Structural interface of the checker object used by this module.

    Any object exposing these two methods satisfies the protocol; the
    concrete instance comes from the optional ``tosa_checker`` package.
    """

    def _get_tosa_compatibility_for_ops(self) -> list[Any]:
        """Return the per-operator compatibility records of the model."""

    def is_tosa_compatible(self) -> bool:
        """Report whether the model as a whole is TOSA compatible."""


@dataclass()
class Operator:
    """Compatibility record for a single operator of a model.

    Fields are positional in the order: location, name, compatibility flag.
    """

    # Operator location, as reported by the checker.
    location: str
    # Operator name, as reported by the checker.
    name: str
    # True when the checker reports this operator as TOSA compatible.
    is_tosa_compatible: bool


@dataclass()
class TOSACompatibilityInfo:
    """TOSA compatibility report for a whole model."""

    # Model-level verdict from the checker.
    tosa_compatible: bool
    # One entry per operator reported by the checker.
    operators: list[Operator]


def get_tosa_compatibility_info(
    tflite_model_path: PathOrFileLike,
) -> TOSACompatibilityInfo:
    """Return TOSA compatibility information for a TFLite model.

    Runs the ``tosa_checker`` backend on the model and collects both the
    model-level verdict and the per-operator results.

    :param tflite_model_path: location of the TFLite model to check.
    :return: the model-level compatibility flag together with one
        ``Operator`` entry per operator reported by the checker.
    :raises BackendUnavailableError: if no checker could be created, i.e.
        ``get_tosa_checker`` returned ``None`` because the ``tosa_checker``
        package could not be imported.
    """
    checker = get_tosa_checker(tflite_model_path)

    if checker is None:
        raise BackendUnavailableError(
            "Backend tosa-checker is not available", "tosa-checker"
        )

    # Per-operator results are only reachable through the checker's
    # protected method, hence the pylint suppression.
    ops = [
        Operator(item.location, item.name, item.is_tosa_compatible)
        for item in checker._get_tosa_compatibility_for_ops()  # pylint: disable=protected-access
    ]

    return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops)


def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None:
    """Create a TOSA checker for the given TFLite model.

    ``tosa_checker`` is imported lazily so that this module can be loaded
    without it; its absence is reported as ``None`` instead of an error.

    :param tflite_model_path: location of the TFLite model. It is passed to
        the checker as ``str(tflite_model_path)``.
        NOTE(review): for a file-like object ``str()`` yields its repr, not a
        path. Confirm that callers only pass paths here.
    :return: the checker instance, or ``None`` if ``tosa_checker`` cannot be
        imported.
    """
    try:
        import tosa_checker  # pylint: disable=import-outside-toplevel
    except ImportError:
        return None

    model_path = str(tflite_model_path)
    return cast(TOSAChecker, tosa_checker.TOSAChecker(model_path))