aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_api.py
blob: 54d4796ac4519811c388416de469cedb4c840317 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the API functions."""
from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock

import pytest

from mlia.api import get_advice
from mlia.core.common import AdviceCategory
from mlia.core.context import Context
from mlia.core.context import ExecutionContext


def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
    """Ensure get_advice refuses to run when the target is None."""
    missing_target = None
    with pytest.raises(Exception, match="Target is not provided"):
        get_advice(missing_target, test_keras_model, "all")  # type: ignore


def test_get_advice_wrong_category(test_keras_model: Path) -> None:
    """Ensure get_advice rejects an advice category it does not recognise."""
    bad_category = "unknown"
    with pytest.raises(Exception, match="Invalid advice category unknown"):
        get_advice("ethos-u55-256", test_keras_model, bad_category)  # type: ignore


@pytest.mark.parametrize(
    "category, context, expected_category",
    [
        [
            "all",
            None,
            AdviceCategory.ALL,
        ],
        [
            "optimization",
            None,
            AdviceCategory.OPTIMIZATION,
        ],
        [
            "operators",
            None,
            AdviceCategory.OPERATORS,
        ],
        [
            "performance",
            None,
            AdviceCategory.PERFORMANCE,
        ],
        [
            "all",
            ExecutionContext(),
            AdviceCategory.ALL,
        ],
        # When a context is supplied, its advice_category is expected to win
        # over the ``category`` argument passed to get_advice.
        [
            "all",
            ExecutionContext(advice_category=AdviceCategory.PERFORMANCE),
            AdviceCategory.PERFORMANCE,
        ],
        [
            "all",
            ExecutionContext(config_parameters={"param": "value"}),
            AdviceCategory.ALL,
        ],
        [
            "all",
            ExecutionContext(event_handlers=[MagicMock()]),
            AdviceCategory.ALL,
        ],
    ],
)
def test_get_advice(
    monkeypatch: pytest.MonkeyPatch,
    category: str,
    context: Optional[ExecutionContext],
    expected_category: AdviceCategory,
    test_keras_model: Path,
) -> None:
    """Test getting advice with valid parameters.

    ``mlia.api._get_advisor`` is patched to return a mock advisor, so the
    test only inspects the context object that get_advice passes to the
    advisor's ``run`` method.
    """
    advisor_mock = MagicMock()
    monkeypatch.setattr("mlia.api._get_advisor", MagicMock(return_value=advisor_mock))

    get_advice(
        "ethos-u55-256",
        test_keras_model,
        category,  # type: ignore
        context=context,
    )

    # get_advice must delegate to the advisor exactly once; the first
    # positional argument of that call is the context it prepared.
    # A distinct name is used so the parametrized ``context`` input is not
    # shadowed by the object captured from the mock.
    advisor_mock.run.assert_called_once()
    run_context = advisor_mock.run.call_args.args[0]
    assert isinstance(run_context, Context)
    assert run_context.advice_category == expected_category

    assert run_context.event_handlers is not None
    assert run_context.config_parameters is not None