1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the API functions."""
from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock

import pytest

from mlia.api import get_advice
from mlia.api import get_advisor
from mlia.core.common import AdviceCategory
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
from mlia.devices.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
    """Check that get_advice refuses to run when the target profile is None."""
    expected_error = "Target profile is not provided"
    with pytest.raises(Exception, match=expected_error):
        get_advice(None, test_keras_model, "all")  # type: ignore


def test_get_advice_wrong_category(test_keras_model: Path) -> None:
    """Check that get_advice refuses an advice category it does not recognise."""
    bogus_category = "unknown"
    with pytest.raises(Exception, match="Invalid advice category unknown"):
        get_advice("ethos-u55-256", test_keras_model, bogus_category)  # type: ignore


@pytest.mark.parametrize(
    "category, context, expected_category",
    [
        [
            "all",
            None,
            AdviceCategory.ALL,
        ],
        [
            "optimization",
            None,
            AdviceCategory.OPTIMIZATION,
        ],
        [
            "operators",
            None,
            AdviceCategory.OPERATORS,
        ],
        [
            "performance",
            None,
            AdviceCategory.PERFORMANCE,
        ],
        [
            "all",
            ExecutionContext(),
            AdviceCategory.ALL,
        ],
        # The category passed to get_advice is expected to win over the one
        # already stored on a caller-supplied context.
        [
            "all",
            ExecutionContext(advice_category=AdviceCategory.PERFORMANCE),
            AdviceCategory.ALL,
        ],
        [
            "all",
            ExecutionContext(config_parameters={"param": "value"}),
            AdviceCategory.ALL,
        ],
        [
            "all",
            ExecutionContext(event_handlers=[MagicMock()]),
            AdviceCategory.ALL,
        ],
    ],
)
def test_get_advice(
    monkeypatch: pytest.MonkeyPatch,
    category: str,
    context: Optional[ExecutionContext],
    expected_category: AdviceCategory,
    test_keras_model: Path,
) -> None:
    """Test getting advice with valid parameters.

    ``mlia.api.get_advisor`` is replaced by a mock so no real advisor runs.
    The test then checks the context handed to ``advisor.run``:

    * it is a ``Context``;
    * its ``advice_category`` matches the category requested from
      ``get_advice``.

    This holds whether the caller passed no context (``None``) or an existing
    ``ExecutionContext``.
    """
    advisor_mock = MagicMock()
    monkeypatch.setattr("mlia.api.get_advisor", MagicMock(return_value=advisor_mock))

    get_advice(
        "ethos-u55-256",
        test_keras_model,
        category,  # type: ignore
        context=context,
    )

    advisor_mock.run.assert_called_once()
    # Use a distinct name so the `context` parameter is not shadowed by the
    # object the advisor actually received.
    run_context = advisor_mock.run.call_args.args[0]
    assert isinstance(run_context, Context)
    assert run_context.advice_category == expected_category


def test_get_advisor(test_keras_model: Path) -> None:
    """Check that get_advisor returns the advisor class matching each target.

    A fresh ``ExecutionContext`` is built for every lookup, and the model is
    passed as a string path.
    """
    model_path = str(test_keras_model)
    expected_advisor_types = [
        ("ethos-u55-256", EthosUInferenceAdvisor),
        ("tosa", TOSAInferenceAdvisor),
    ]
    for target, advisor_type in expected_advisor_types:
        advisor = get_advisor(ExecutionContext(), target, model_path)
        assert isinstance(advisor, advisor_type)
|