1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U MLIA module."""
from __future__ import annotations
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
import pytest
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
from mlia.target.ethos_u.advisor import EthosUInferenceAdvisor
def test_advisor_metadata() -> None:
    """Check that the Ethos-U inference advisor reports its expected name."""
    expected_name = "ethos_u_inference_advisor"
    actual_name = EthosUInferenceAdvisor.name()
    assert actual_name == expected_name
@pytest.mark.parametrize(
    "optimization_targets, expected_error",
    [
        [
            [
                {
                    "optimization_type": "pruning",
                    "optimization_target": 0.5,
                    "layers_to_optimize": None,
                }
            ],
            # `match` is a regular expression (searched in the message), so
            # the trailing period is escaped to match it literally.
            pytest.raises(
                Exception,
                match=r"Only 'rewrite' is supported for TensorFlow Lite files\.",
            ),
        ],
        [
            [
                {
                    "optimization_type": "rewrite",
                    "optimization_target": "fully_connected",
                    "layers_to_optimize": [
                        "MobileNet/avg_pool/AvgPool",
                        "MobileNet/fc1/BiasAdd",
                    ],
                }
            ],
            does_not_raise(),
        ],
    ],
    ids=["pruning-on-tflite-raises", "rewrite-on-tflite-ok"],
)
def test_unsupported_advice_categories(
    tmp_path: Path,
    test_tflite_model: Path,
    optimization_targets: list[dict[str, Any]],
    expected_error: Any,
) -> None:
    """Test which optimization types the advisor accepts for a TFLite model.

    A "pruning" optimization target must be rejected with an error saying
    that only "rewrite" is supported for TensorFlow Lite files, while a
    "rewrite" target must be configured without raising.
    """
    # Build the execution context outside the expectation: only advisor
    # configuration is under test, so a failure while setting up the context
    # must surface as a test error rather than satisfy `pytest.raises`.
    ctx = ExecutionContext(
        output_dir=tmp_path, advice_category={AdviceCategory.OPTIMIZATION}
    )

    with expected_error:
        advisor = configure_and_get_ethosu_advisor(
            ctx,
            "ethos-u55-256",
            str(test_tflite_model),
            optimization_targets=optimization_targets,
        )
        advisor.configure(ctx)
|