aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_core_context.py
blob: 10015aac7b76682a1e86116bb48e52b163bb1a6b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module context."""
from pathlib import Path

from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.events import DefaultEventPublisher


def test_execution_context(tmpdir: str) -> None:
    """Check ExecutionContext built with explicit arguments and with defaults.

    The first half constructs a context with every argument supplied and
    compares its attributes, derived paths and string form against the
    supplied values. The second half constructs a context from the working
    directory alone and compares the same properties against the values this
    test expects when nothing else is passed.
    """
    base_dir = Path(tmpdir)
    event_publisher = DefaultEventPublisher()
    advice_category = AdviceCategory.OPERATORS

    # --- Context with every argument given explicitly -----------------------
    explicit_ctx = ExecutionContext(
        advice_category=advice_category,
        config_parameters={"param": "value"},
        working_dir=tmpdir,
        event_handlers=[],
        event_publisher=event_publisher,
        verbose=True,
        logs_dir="logs_directory",
        models_dir="models_directory",
    )

    assert explicit_ctx.advice_category == advice_category
    assert explicit_ctx.config_parameters == {"param": "value"}
    assert explicit_ctx.event_handlers == []
    assert explicit_ctx.event_publisher == event_publisher

    # Logs and models directories are expected under the working directory.
    assert explicit_ctx.logs_path == base_dir / "logs_directory"
    model_path = explicit_ctx.get_model_path("sample.model")
    assert model_path == base_dir / "models_directory/sample.model"

    assert explicit_ctx.verbose is True

    explicit_repr = (
        f"ExecutionContext: "
        f"working_dir={tmpdir}, "
        "advice_category=OPERATORS, "
        "config_parameters={'param': 'value'}, "
        "verbose=True"
    )
    assert str(explicit_ctx) == explicit_repr

    # --- Context built from the working directory only ----------------------
    default_ctx = ExecutionContext(working_dir=tmpdir)

    assert default_ctx.advice_category is None
    assert default_ctx.config_parameters is None
    assert default_ctx.event_handlers is None
    assert isinstance(default_ctx.event_publisher, DefaultEventPublisher)

    assert default_ctx.logs_path == base_dir / "logs"
    assert default_ctx.get_model_path("sample.model") == (
        base_dir / "models/sample.model"
    )

    default_repr = (
        f"ExecutionContext: working_dir={tmpdir}, "
        "advice_category=<not set>, "
        "config_parameters=None, "
        "verbose=False"
    )
    assert str(default_ctx) == default_repr