1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
|
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module context."""
from pathlib import Path
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.events import DefaultEventPublisher
def test_execution_context(tmpdir: str) -> None:
    """Check ExecutionContext built with explicit arguments and with defaults.

    The first half passes every constructor argument and verifies that each
    one is reflected by the matching attribute, by the path helpers
    (``logs_path`` / ``get_model_path``) and by ``str(context)``.  The second
    half passes only ``working_dir`` and verifies the fallback values.
    """
    # NOTE(review): pytest's ``tmpdir`` fixture yields a ``py.path.local``
    # rather than a ``str``; the annotation is kept unchanged here.
    work_dir = Path(tmpdir)

    # --- Context with every argument supplied --------------------------------
    event_publisher = DefaultEventPublisher()
    operators_category = AdviceCategory.OPERATORS

    configured = ExecutionContext(
        advice_category=operators_category,
        config_parameters={"param": "value"},
        working_dir=tmpdir,
        event_handlers=[],
        event_publisher=event_publisher,
        verbose=True,
        logs_dir="logs_directory",
        models_dir="models_directory",
    )

    assert configured.advice_category == operators_category
    assert configured.config_parameters == {"param": "value"}
    assert configured.event_handlers == []
    assert configured.event_publisher == event_publisher
    # The custom logs/models directory names are expected beneath working_dir.
    assert configured.logs_path == work_dir / "logs_directory"
    assert configured.get_model_path("sample.model") == (
        work_dir / "models_directory/sample.model"
    )
    assert configured.verbose is True
    assert str(configured) == (
        "ExecutionContext: "
        f"working_dir={tmpdir}, "
        "advice_category=OPERATORS, "
        "config_parameters={'param': 'value'}, "
        "verbose=True"
    )

    # --- Context with only working_dir supplied ------------------------------
    defaults_only = ExecutionContext(working_dir=tmpdir)

    assert defaults_only.advice_category is None
    assert defaults_only.config_parameters is None
    assert defaults_only.event_handlers is None
    # Without an explicit publisher a DefaultEventPublisher is expected.
    assert isinstance(defaults_only.event_publisher, DefaultEventPublisher)
    # Default directory names are "logs" and "models" under working_dir.
    assert defaults_only.logs_path == work_dir / "logs"
    assert defaults_only.get_model_path("sample.model") == (
        work_dir / "models/sample.model"
    )
    assert str(defaults_only) == (
        f"ExecutionContext: working_dir={tmpdir}, "
        "advice_category=<not set>, "
        "config_parameters=None, "
        "verbose=False"
    )
|