aboutsummaryrefslogtreecommitdiff
path: root/tests/test_backend_tosa_compat.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_backend_tosa_compat.py')
-rw-r--r--tests/test_backend_tosa_compat.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/tests/test_backend_tosa_compat.py b/tests/test_backend_tosa_compat.py
new file mode 100644
index 0000000..4c4dc5a
--- /dev/null
+++ b/tests/test_backend_tosa_compat.py
@@ -0,0 +1,86 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA compatibility."""
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info
+from mlia.backend.tosa_checker.compat import Operator
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
+
+
+def replace_get_tosa_checker_with_mock(
+ monkeypatch: pytest.MonkeyPatch, mock: MagicMock | None
+) -> None:
+ """Replace TOSA checker with mock."""
+ monkeypatch.setattr(
+ "mlia.backend.tosa_checker.compat.get_tosa_checker",
+ MagicMock(return_value=mock),
+ )
+
+
+def test_compatibility_check_should_fail_if_checker_not_available(
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path
+) -> None:
+ """Test that compatibility check should fail if TOSA checker is not available."""
+ replace_get_tosa_checker_with_mock(monkeypatch, None)
+
+ with pytest.raises(Exception, match="TOSA checker is not available"):
+ get_tosa_compatibility_info(test_tflite_model)
+
+
+@pytest.mark.parametrize(
+ "is_tosa_compatible, operators, expected_result",
+ [
+ [
+ True,
+ [],
+ TOSACompatibilityInfo(True, []),
+ ],
+ [
+ True,
+ [
+ SimpleNamespace(
+ location="op_location",
+ name="op_name",
+ is_tosa_compatible=True,
+ )
+ ],
+ TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]),
+ ],
+ [
+ False,
+ [
+ SimpleNamespace(
+ location="op_location",
+ name="op_name",
+ is_tosa_compatible=False,
+ )
+ ],
+ TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]),
+ ],
+ ],
+)
+def test_get_tosa_compatibility_info(
+ monkeypatch: pytest.MonkeyPatch,
+ test_tflite_model: Path,
+ is_tosa_compatible: bool,
+ operators: Any,
+ expected_result: TOSACompatibilityInfo,
+) -> None:
+ """Test getting TOSA compatibility information."""
+ mock_checker = MagicMock()
+ mock_checker.is_tosa_compatible.return_value = is_tosa_compatible
+ mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access
+ operators
+ )
+
+ replace_get_tosa_checker_with_mock(monkeypatch, mock_checker)
+
+ assert get_tosa_compatibility_info(test_tflite_model) == expected_result