diff options
Diffstat (limited to 'tests/test_cli_command_validators.py')
-rw-r--r-- | tests/test_cli_command_validators.py | 167 |
1 files changed, 167 insertions, 0 deletions
diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py new file mode 100644 index 0000000..13514a5 --- /dev/null +++ b/tests/test_cli_command_validators.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for cli.command_validators module.""" +from __future__ import annotations + +import argparse +from unittest.mock import MagicMock + +import pytest + +from mlia.cli.command_validators import validate_backend +from mlia.cli.command_validators import validate_check_target_profile + + +@pytest.mark.parametrize( + "target_profile, category, expected_warnings, sys_exits", + [ + ["ethos-u55-256", {"compatibility", "performance"}, [], False], + ["ethos-u55-256", {"compatibility"}, [], False], + ["ethos-u55-256", {"performance"}, [], False], + [ + "tosa", + {"compatibility", "performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile tosa." + ) + ], + False, + ], + [ + "tosa", + {"performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile tosa. No operation was performed." + ) + ], + True, + ], + ["tosa", "compatibility", [], False], + [ + "cortex-a", + {"performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile cortex-a. " + "No operation was performed." + ) + ], + True, + ], + [ + "cortex-a", + {"compatibility", "performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile cortex-a." + ) + ], + False, + ], + ["cortex-a", "compatibility", [], False], + ], +) +def test_validate_check_target_profile( + caplog: pytest.LogCaptureFixture, + target_profile: str, + category: set[str], + expected_warnings: list[str], + sys_exits: bool, +) -> None: + """Test outcomes of category dependent target profile validation.""" + # Capture if program terminates + if sys_exits: + with pytest.raises(SystemExit) as sys_ex: + validate_check_target_profile(target_profile, category) + assert sys_ex.value.code == 0 + return + + validate_check_target_profile(target_profile, category) + + log_records = caplog.records + # Get all log records with level 30 (warning level) + warning_messages = {x.message for x in log_records if x.levelno == 30} + # Ensure the warnings coincide with the expected ones + assert warning_messages == set(expected_warnings) + + +@pytest.mark.parametrize( + "input_target_profile, input_backends, throws_exception," + "exception_message, output_backends", + [ + [ + "tosa", + ["Vela"], + True, + "Vela backend not supported with target-profile tosa.", + None, + ], + [ + "tosa", + ["Corstone-300, Vela"], + True, + "Corstone-300, Vela backend not supported with target-profile tosa.", + None, + ], + [ + "cortex-a", + ["Corstone-310", "tosa-checker"], + True, + "Corstone-310, tosa-checker backend not supported " + "with target-profile cortex-a.", + None, + ], + [ + "ethos-u55-256", + ["tosa-checker", "Corstone-310"], + True, + "tosa-checker backend not supported with target-profile ethos-u55-256.", + None, + ], + ["tosa", None, False, None, ["tosa-checker"]], + ["cortex-a", None, False, None, ["armnn-tflitedelegate"]], + ["tosa", ["tosa-checker"], False, None, ["tosa-checker"]], + ["cortex-a", ["armnn-tflitedelegate"], False, None, ["armnn-tflitedelegate"]], + [ + "ethos-u55-256", + ["Vela", "Corstone-300"], + False, + None, + ["Vela", "Corstone-300"], + ], + [ + "ethos-u55-256", + None, + False, + None, + ["Vela", "Corstone-300"], + ], + ], +) +def test_validate_backend( + monkeypatch: pytest.MonkeyPatch, + input_target_profile: str, + input_backends: list[str] | None, + throws_exception: bool, + exception_message: str, + output_backends: list[str] | None, +) -> None: + """Test backend validation with target-profiles and backends.""" + monkeypatch.setattr( + "mlia.cli.config.get_available_backends", + MagicMock(return_value=["Vela", "Corstone-300"]), + ) + + if throws_exception: + with pytest.raises(argparse.ArgumentError) as err: + validate_backend(input_target_profile, input_backends) + assert str(err.value.message) == exception_message + return + + assert validate_backend(input_target_profile, input_backends) == output_backends |