diff options
Diffstat (limited to 'tests/test_api.py')
-rw-r--r-- | tests/test_api.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/tests/test_api.py b/tests/test_api.py index 7b567bf..6fa15b3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,11 +1,15 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the API functions.""" +from __future__ import annotations + from pathlib import Path from unittest.mock import MagicMock +from unittest.mock import patch import pytest +from mlia.api import generate_supported_operators_report from mlia.api import get_advice from mlia.api import get_advisor from mlia.core.common import AdviceCategory @@ -107,3 +111,50 @@ def test_get_advisor( tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model)) assert isinstance(tosa_advisor, TOSAInferenceAdvisor) + + +@pytest.mark.parametrize( + ["target_profile", "required_calls", "exception_msg"], + [ + [ + "ethos-u55-128", + "mlia.tools.vela_wrapper.generate_supported_operators_report", + None, + ], + [ + "ethos-u65-256", + "mlia.tools.vela_wrapper.generate_supported_operators_report", + None, + ], + [ + "tosa", + None, + "Generating a supported operators report is not " + "currently supported with TOSA target profile.", + ], + [ + "cortex-a", + None, + "Generating a supported operators report is not " + "currently supported with Cortex-A target profile.", + ], + [ + "Unknown", + None, + "Unable to find target profile Unknown", + ], + ], +) +def test_supported_ops_report_generator( + target_profile: str, required_calls: str | None, exception_msg: str | None +) -> None: + """Test supported operators report generator with different target profiles.""" + if exception_msg: + with pytest.raises(Exception) as exc: + generate_supported_operators_report(target_profile) + assert str(exc.value) == exception_msg + + if required_calls: + with patch(required_calls) as mock_method: + generate_supported_operators_report(target_profile) + mock_method.assert_called_once() |