aboutsummaryrefslogtreecommitdiff
path: root/tests/test_api.py
blob: b40c55b08bdc0255eb6b878aa35da112bf7c5d95 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the API functions."""
from __future__ import annotations

from pathlib import Path
from unittest.mock import MagicMock

import pytest

from mlia.api import get_advice
from mlia.api import get_advisor
from mlia.core.common import AdviceCategory
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.target.ethos_u.advisor import EthosUInferenceAdvisor
from mlia.target.tosa.advisor import TOSAInferenceAdvisor


def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
    """Ensure get_advice refuses to run when the target profile is None."""
    expected_error = "No valid target profile was provided."
    with pytest.raises(Exception, match=expected_error):
        get_advice(None, test_keras_model, {"compatibility"})  # type: ignore


def test_get_advice_wrong_category(test_keras_model: Path) -> None:
    """Ensure get_advice rejects an advice category name it does not recognise."""
    bogus_categories = {"unknown"}
    with pytest.raises(Exception, match="Invalid advice category unknown"):
        get_advice("ethos-u55-256", test_keras_model, bogus_categories)


@pytest.mark.parametrize(
    "category, context, expected_category",
    [
        [
            {"compatibility", "optimization"},
            None,
            {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
        ],
        [
            {"optimization"},
            None,
            {AdviceCategory.OPTIMIZATION},
        ],
        [
            {"compatibility"},
            None,
            {AdviceCategory.COMPATIBILITY},
        ],
        [
            {"performance"},
            None,
            {AdviceCategory.PERFORMANCE},
        ],
        [
            {"compatibility", "optimization"},
            ExecutionContext(),
            {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
        ],
        [
            {"compatibility", "optimization"},
            ExecutionContext(
                advice_category={
                    AdviceCategory.COMPATIBILITY,
                    AdviceCategory.OPTIMIZATION,
                }
            ),
            {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
        ],
        [
            {"compatibility", "optimization"},
            ExecutionContext(config_parameters={"param": "value"}),
            {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
        ],
        [
            {"compatibility", "optimization"},
            ExecutionContext(event_handlers=[MagicMock()]),
            {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
        ],
    ],
)
def test_get_advice(
    monkeypatch: pytest.MonkeyPatch,
    category: set[str],
    context: ExecutionContext | None,
    expected_category: set[AdviceCategory],
    test_keras_model: Path,
) -> None:
    """Test getting advice with valid parameters.

    ``mlia.api.get_advisor`` is replaced with a mock, so no real advisor is
    built. The test then checks two things about ``get_advice``:

    - it calls the advisor's ``run`` exactly once;
    - the single argument it passes is a ``Context`` whose ``advice_category``
      equals the expected set of ``AdviceCategory`` values.

    Both checks are run with and without a caller-supplied ``context``.
    """
    advisor_mock = MagicMock()
    monkeypatch.setattr("mlia.api.get_advisor", MagicMock(return_value=advisor_mock))

    get_advice(
        "ethos-u55-256",
        test_keras_model,
        category,
        context=context,
    )

    advisor_mock.run.assert_called_once()
    # Bind to a new name instead of rebinding the ``context`` parameter, so the
    # context given to get_advice and the one actually passed to run() stay
    # distinct.
    run_context = advisor_mock.run.call_args.args[0]
    assert isinstance(run_context, Context)
    assert run_context.advice_category == expected_category


def test_get_advisor(
    test_keras_model: Path,
) -> None:
    """Check that get_advisor returns the advisor class matching each target."""
    model_path = str(test_keras_model)

    # Each target profile name paired with the advisor type it should yield.
    expectations = (
        ("ethos-u55-256", EthosUInferenceAdvisor),
        ("tosa", TOSAInferenceAdvisor),
    )
    for target, advisor_cls in expectations:
        advisor = get_advisor(ExecutionContext(), target, model_path)
        assert isinstance(advisor, advisor_cls)