diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-11-24 08:34:38 +0000 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-11-29 14:44:13 +0000 |
commit | a34163c9d9a5cc0416bcaea2ebf8383bda9d505c (patch) | |
tree | 304c01c607b3a93c250a38df53c417f62196b5fa /tests/test_backend_tosa_compat.py | |
parent | 37959522a805a5e23c930ed79aac84920c3cb208 (diff) | |
download | mlia-a34163c9d9a5cc0416bcaea2ebf8383bda9d505c.tar.gz |
Move TOSA checker functions into separate module
- Create module "compat" for tosa_checker backend
- Move TOSA checker functions into new module
- Update tests
Change-Id: Ia07034515fe43b2061b8892535067d21315cc721
Diffstat (limited to 'tests/test_backend_tosa_compat.py')
-rw-r--r-- | tests/test_backend_tosa_compat.py | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/tests/test_backend_tosa_compat.py b/tests/test_backend_tosa_compat.py new file mode 100644 index 0000000..4c4dc5a --- /dev/null +++ b/tests/test_backend_tosa_compat.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA compatibility.""" +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info +from mlia.backend.tosa_checker.compat import Operator +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo + + +def replace_get_tosa_checker_with_mock( + monkeypatch: pytest.MonkeyPatch, mock: MagicMock | None +) -> None: + """Replace TOSA checker with mock.""" + monkeypatch.setattr( + "mlia.backend.tosa_checker.compat.get_tosa_checker", + MagicMock(return_value=mock), + ) + + +def test_compatibility_check_should_fail_if_checker_not_available( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path +) -> None: + """Test that compatibility check should fail if TOSA checker is not available.""" + replace_get_tosa_checker_with_mock(monkeypatch, None) + + with pytest.raises(Exception, match="TOSA checker is not available"): + get_tosa_compatibility_info(test_tflite_model) + + +@pytest.mark.parametrize( + "is_tosa_compatible, operators, expected_result", + [ + [ + True, + [], + TOSACompatibilityInfo(True, []), + ], + [ + True, + [ + SimpleNamespace( + location="op_location", + name="op_name", + is_tosa_compatible=True, + ) + ], + TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]), + ], + [ + False, + [ + SimpleNamespace( + location="op_location", + name="op_name", + is_tosa_compatible=False, + ) + ], + TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]), + ], + ], +) +def test_get_tosa_compatibility_info( + monkeypatch: pytest.MonkeyPatch, + test_tflite_model: Path, + is_tosa_compatible: bool, + operators: Any, + expected_result: TOSACompatibilityInfo, +) -> None: + """Test getting TOSA compatibility information.""" + mock_checker = MagicMock() + mock_checker.is_tosa_compatible.return_value = is_tosa_compatible + mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access + operators + ) + + replace_get_tosa_checker_with_mock(monkeypatch, mock_checker) + + assert get_tosa_compatibility_info(test_tflite_model) == expected_result |