aboutsummaryrefslogtreecommitdiff
path: root/tests/test_target_ethos_u_data_collection.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_target_ethos_u_data_collection.py')
-rw-r--r--tests/test_target_ethos_u_data_collection.py62
1 files changed, 40 insertions, 22 deletions
diff --git a/tests/test_target_ethos_u_data_collection.py b/tests/test_target_ethos_u_data_collection.py
index 6244f8b..be93c26 100644
--- a/tests/test_target_ethos_u_data_collection.py
+++ b/tests/test_target_ethos_u_data_collection.py
@@ -8,9 +8,11 @@ import pytest
from mlia.backend.vela.compat import Operators
from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
from mlia.core.data_collection import DataCollector
from mlia.core.errors import FunctionalityNotSupportedError
from mlia.nn.select import OptimizationSettings
+from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility
from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
@@ -46,6 +48,20 @@ def test_collectors_metadata(
assert collector.name() == expected_name
+def setup_optimization(optimizations: list) -> Context:
+ """Set up optimization params for the context."""
+ params: dict = {}
+ add_common_optimization_params(
+ params,
+ {
+ "optimization_targets": optimizations,
+ },
+ )
+
+ context = ExecutionContext(config_parameters=params)
+ return context
+
+
def test_operator_compatibility_collector(
sample_context: Context, test_tflite_model: Path
) -> None:
@@ -76,7 +92,6 @@ def test_performance_collector(
def test_optimization_performance_collector(
monkeypatch: pytest.MonkeyPatch,
- sample_context: Context,
test_keras_model: Path,
test_tflite_model: Path,
) -> None:
@@ -84,16 +99,14 @@ def test_optimization_performance_collector(
target = EthosUConfiguration.load_profile("ethos-u55-256")
mock_performance_estimation(monkeypatch, target)
- collector = EthosUOptimizationPerformance(
- test_keras_model,
- target,
+
+ context = setup_optimization(
[
- [
- {"optimization_type": "pruning", "optimization_target": 0.5},
- ]
+ {"optimization_type": "pruning", "optimization_target": 0.5},
],
)
- collector.set_context(sample_context)
+ collector = EthosUOptimizationPerformance(test_keras_model, target)
+ collector.set_context(context)
result = collector.collect_data()
assert isinstance(result, OptimizationPerformanceMetrics)
@@ -105,34 +118,39 @@ def test_optimization_performance_collector(
assert opt == [OptimizationSettings("pruning", 0.5, None)]
assert isinstance(metrics, PerformanceMetrics)
- collector_no_optimizations = EthosUOptimizationPerformance(
- test_keras_model,
- target,
- [],
+ context = ExecutionContext(
+ config_parameters={"common_optimizations": {"optimizations": [[]]}}
)
+
+ collector_no_optimizations = EthosUOptimizationPerformance(test_keras_model, target)
+ collector_no_optimizations.set_context(context)
with pytest.raises(FunctionalityNotSupportedError):
collector_no_optimizations.collect_data()
- collector_tflite = EthosUOptimizationPerformance(
- test_tflite_model,
- target,
+ context = setup_optimization(
[
- [
- {"optimization_type": "pruning", "optimization_target": 0.5},
- ]
+ {"optimization_type": "pruning", "optimization_target": 0.5},
],
)
- collector_tflite.set_context(sample_context)
+
+ collector_tflite = EthosUOptimizationPerformance(test_tflite_model, target)
+ collector_tflite.set_context(context)
with pytest.raises(FunctionalityNotSupportedError):
collector_tflite.collect_data()
with pytest.raises(
Exception, match="Optimization parameters expected to be a list"
):
- collector_bad_config = EthosUOptimizationPerformance(
- test_keras_model, target, {"optimization_type": "pruning"} # type: ignore
+ context = ExecutionContext(
+ config_parameters={
+ "common_optimizations": {
+ "optimizations": [{"optimization_type": "pruning"}]
+ }
+ }
)
- collector.set_context(sample_context)
+
+ collector_bad_config = EthosUOptimizationPerformance(test_keras_model, target)
+ collector_bad_config.set_context(context)
collector_bad_config.collect_data()