aboutsummaryrefslogtreecommitdiff
path: root/tests/test_cli_command_validators.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_cli_command_validators.py')
-rw-r--r--tests/test_cli_command_validators.py167
1 files changed, 167 insertions, 0 deletions
diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py
new file mode 100644
index 0000000..13514a5
--- /dev/null
+++ b/tests/test_cli_command_validators.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.command_validators module."""
+from __future__ import annotations
+
+import argparse
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.command_validators import validate_backend
+from mlia.cli.command_validators import validate_check_target_profile
+
+
+@pytest.mark.parametrize(
+ "target_profile, category, expected_warnings, sys_exits",
+ [
+ ["ethos-u55-256", {"compatibility", "performance"}, [], False],
+ ["ethos-u55-256", {"compatibility"}, [], False],
+ ["ethos-u55-256", {"performance"}, [], False],
+ [
+ "tosa",
+ {"compatibility", "performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile tosa."
+ )
+ ],
+ False,
+ ],
+ [
+ "tosa",
+ {"performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile tosa. No operation was performed."
+ )
+ ],
+ True,
+ ],
+ ["tosa", "compatibility", [], False],
+ [
+ "cortex-a",
+ {"performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile cortex-a. "
+ "No operation was performed."
+ )
+ ],
+ True,
+ ],
+ [
+ "cortex-a",
+ {"compatibility", "performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile cortex-a."
+ )
+ ],
+ False,
+ ],
+ ["cortex-a", "compatibility", [], False],
+ ],
+)
+def test_validate_check_target_profile(
+ caplog: pytest.LogCaptureFixture,
+ target_profile: str,
+ category: set[str],
+ expected_warnings: list[str],
+ sys_exits: bool,
+) -> None:
+ """Test outcomes of category dependent target profile validation."""
+ # Capture if program terminates
+ if sys_exits:
+ with pytest.raises(SystemExit) as sys_ex:
+ validate_check_target_profile(target_profile, category)
+ assert sys_ex.value.code == 0
+ return
+
+ validate_check_target_profile(target_profile, category)
+
+ log_records = caplog.records
+ # Get all log records with level 30 (warning level)
+ warning_messages = {x.message for x in log_records if x.levelno == 30}
+ # Ensure the warnings coincide with the expected ones
+ assert warning_messages == set(expected_warnings)
+
+
+@pytest.mark.parametrize(
+ "input_target_profile, input_backends, throws_exception,"
+ "exception_message, output_backends",
+ [
+ [
+ "tosa",
+ ["Vela"],
+ True,
+ "Vela backend not supported with target-profile tosa.",
+ None,
+ ],
+ [
+ "tosa",
+ ["Corstone-300, Vela"],
+ True,
+ "Corstone-300, Vela backend not supported with target-profile tosa.",
+ None,
+ ],
+ [
+ "cortex-a",
+ ["Corstone-310", "tosa-checker"],
+ True,
+ "Corstone-310, tosa-checker backend not supported "
+ "with target-profile cortex-a.",
+ None,
+ ],
+ [
+ "ethos-u55-256",
+ ["tosa-checker", "Corstone-310"],
+ True,
+ "tosa-checker backend not supported with target-profile ethos-u55-256.",
+ None,
+ ],
+ ["tosa", None, False, None, ["tosa-checker"]],
+ ["cortex-a", None, False, None, ["armnn-tflitedelegate"]],
+ ["tosa", ["tosa-checker"], False, None, ["tosa-checker"]],
+ ["cortex-a", ["armnn-tflitedelegate"], False, None, ["armnn-tflitedelegate"]],
+ [
+ "ethos-u55-256",
+ ["Vela", "Corstone-300"],
+ False,
+ None,
+ ["Vela", "Corstone-300"],
+ ],
+ [
+ "ethos-u55-256",
+ None,
+ False,
+ None,
+ ["Vela", "Corstone-300"],
+ ],
+ ],
+)
+def test_validate_backend(
+ monkeypatch: pytest.MonkeyPatch,
+ input_target_profile: str,
+ input_backends: list[str] | None,
+ throws_exception: bool,
+ exception_message: str,
+ output_backends: list[str] | None,
+) -> None:
+ """Test backend validation with target-profiles and backends."""
+ monkeypatch.setattr(
+ "mlia.cli.config.get_available_backends",
+ MagicMock(return_value=["Vela", "Corstone-300"]),
+ )
+
+ if throws_exception:
+ with pytest.raises(argparse.ArgumentError) as err:
+ validate_backend(input_target_profile, input_backends)
+ assert str(err.value.message) == exception_message
+ return
+
+ assert validate_backend(input_target_profile, input_backends) == output_backends