aboutsummaryrefslogtreecommitdiff
path: root/tests/test_core_context.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_core_context.py')
-rw-r--r--tests/test_core_context.py46
1 files changed, 41 insertions, 5 deletions
diff --git a/tests/test_core_context.py b/tests/test_core_context.py
index 44eb976..dcdbef3 100644
--- a/tests/test_core_context.py
+++ b/tests/test_core_context.py
@@ -1,17 +1,53 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module context."""
+from __future__ import annotations
+
from pathlib import Path
+import pytest
+
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.events import DefaultEventPublisher
+@pytest.mark.parametrize(
+ "context_advice_category, expected_enabled_categories",
+ [
+ [
+ {
+ AdviceCategory.COMPATIBILITY,
+ },
+ [AdviceCategory.COMPATIBILITY],
+ ],
+ [
+ {
+ AdviceCategory.PERFORMANCE,
+ },
+ [AdviceCategory.PERFORMANCE],
+ ],
+ [
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE},
+ [AdviceCategory.PERFORMANCE, AdviceCategory.COMPATIBILITY],
+ ],
+ ],
+)
+def test_execution_context_category_enabled(
+ context_advice_category: set[AdviceCategory],
+ expected_enabled_categories: list[AdviceCategory],
+) -> None:
+ """Test category enabled method of execution context."""
+ for category in expected_enabled_categories:
+ assert ExecutionContext(
+ advice_category=context_advice_category
+ ).category_enabled(category)
+
+
def test_execution_context(tmpdir: str) -> None:
"""Test execution context."""
publisher = DefaultEventPublisher()
- category = AdviceCategory.OPERATORS
+ category = {AdviceCategory.COMPATIBILITY}
context = ExecutionContext(
advice_category=category,
@@ -35,13 +71,13 @@ def test_execution_context(tmpdir: str) -> None:
assert str(context) == (
f"ExecutionContext: "
f"working_dir={tmpdir}, "
- "advice_category=OPERATORS, "
+ "advice_category={'COMPATIBILITY'}, "
"config_parameters={'param': 'value'}, "
"verbose=True"
)
context_with_default_params = ExecutionContext(working_dir=tmpdir)
- assert context_with_default_params.advice_category is AdviceCategory.ALL
+ assert context_with_default_params.advice_category == {AdviceCategory.COMPATIBILITY}
assert context_with_default_params.config_parameters is None
assert context_with_default_params.event_handlers is None
assert isinstance(
@@ -55,7 +91,7 @@ def test_execution_context(tmpdir: str) -> None:
expected_str = (
f"ExecutionContext: working_dir={tmpdir}, "
- "advice_category=ALL, "
+ "advice_category={'COMPATIBILITY'}, "
"config_parameters=None, "
"verbose=False"
)