aboutsummaryrefslogtreecommitdiff
path: root/tests/test_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_api.py')
-rw-r--r--tests/test_api.py98
1 files changed, 27 insertions, 71 deletions
diff --git a/tests/test_api.py b/tests/test_api.py
index fbc558b..0bbc3ae 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,15 +1,13 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the API functions."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock
-from unittest.mock import patch
import pytest
-from mlia.api import generate_supported_operators_report
from mlia.api import get_advice
from mlia.api import get_advisor
from mlia.core.common import AdviceCategory
@@ -22,63 +20,68 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
"""Test getting advice when no target provided."""
with pytest.raises(Exception, match="Target profile is not provided"):
- get_advice(None, test_keras_model, "all") # type: ignore
+ get_advice(None, test_keras_model, {"compatibility"}) # type: ignore
def test_get_advice_wrong_category(test_keras_model: Path) -> None:
"""Test getting advice when wrong advice category provided."""
with pytest.raises(Exception, match="Invalid advice category unknown"):
- get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore
+ get_advice("ethos-u55-256", test_keras_model, {"unknown"})
@pytest.mark.parametrize(
"category, context, expected_category",
[
[
- "all",
+ {"compatibility", "optimization"},
None,
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "optimization",
+ {"optimization"},
None,
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
],
[
- "operators",
+ {"compatibility"},
None,
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
],
[
- "performance",
+ {"performance"},
None,
- AdviceCategory.PERFORMANCE,
+ {AdviceCategory.PERFORMANCE},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
- ExecutionContext(advice_category=AdviceCategory.PERFORMANCE),
- AdviceCategory.ALL,
+ {"compatibility", "optimization"},
+ ExecutionContext(
+ advice_category={
+ AdviceCategory.COMPATIBILITY,
+ AdviceCategory.OPTIMIZATION,
+ }
+ ),
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(config_parameters={"param": "value"}),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(event_handlers=[MagicMock()]),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
],
)
def test_get_advice(
monkeypatch: pytest.MonkeyPatch,
- category: str,
+ category: set[str],
context: ExecutionContext,
expected_category: AdviceCategory,
test_keras_model: Path,
@@ -90,7 +93,7 @@ def test_get_advice(
get_advice(
"ethos-u55-256",
test_keras_model,
- category, # type: ignore
+ category,
context=context,
)
@@ -111,50 +114,3 @@ def test_get_advisor(
tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model))
assert isinstance(tosa_advisor, TOSAInferenceAdvisor)
-
-
-@pytest.mark.parametrize(
- ["target_profile", "required_calls", "exception_msg"],
- [
- [
- "ethos-u55-128",
- "mlia.target.ethos_u.operators.generate_supported_operators_report",
- None,
- ],
- [
- "ethos-u65-256",
- "mlia.target.ethos_u.operators.generate_supported_operators_report",
- None,
- ],
- [
- "tosa",
- None,
- "Generating a supported operators report is not "
- "currently supported with TOSA target profile.",
- ],
- [
- "cortex-a",
- None,
- "Generating a supported operators report is not "
- "currently supported with Cortex-A target profile.",
- ],
- [
- "Unknown",
- None,
- "Unable to find target profile Unknown",
- ],
- ],
-)
-def test_supported_ops_report_generator(
- target_profile: str, required_calls: str | None, exception_msg: str | None
-) -> None:
- """Test supported operators report generator with different target profiles."""
- if exception_msg:
- with pytest.raises(Exception) as exc:
- generate_supported_operators_report(target_profile)
- assert str(exc.value) == exception_msg
-
- if required_calls:
- with patch(required_calls) as mock_method:
- generate_supported_operators_report(target_profile)
- mock_method.assert_called_once()