diff options
Diffstat (limited to 'tests/test_target_ethos_u_data_collection.py')
-rw-r--r-- | tests/test_target_ethos_u_data_collection.py | 62 |
1 files changed, 40 insertions, 22 deletions
diff --git a/tests/test_target_ethos_u_data_collection.py b/tests/test_target_ethos_u_data_collection.py index 6244f8b..be93c26 100644 --- a/tests/test_target_ethos_u_data_collection.py +++ b/tests/test_target_ethos_u_data_collection.py @@ -8,9 +8,11 @@ import pytest from mlia.backend.vela.compat import Operators from mlia.core.context import Context +from mlia.core.context import ExecutionContext from mlia.core.data_collection import DataCollector from mlia.core.errors import FunctionalityNotSupportedError from mlia.nn.select import OptimizationSettings +from mlia.target.common.optimization import add_common_optimization_params from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance @@ -46,6 +48,20 @@ def test_collectors_metadata( assert collector.name() == expected_name +def setup_optimization(optimizations: list) -> Context: + """Set up optimization params for the context.""" + params: dict = {} + add_common_optimization_params( + params, + { + "optimization_targets": optimizations, + }, + ) + + context = ExecutionContext(config_parameters=params) + return context + + def test_operator_compatibility_collector( sample_context: Context, test_tflite_model: Path ) -> None: @@ -76,7 +92,6 @@ def test_performance_collector( def test_optimization_performance_collector( monkeypatch: pytest.MonkeyPatch, - sample_context: Context, test_keras_model: Path, test_tflite_model: Path, ) -> None: @@ -84,16 +99,14 @@ def test_optimization_performance_collector( target = EthosUConfiguration.load_profile("ethos-u55-256") mock_performance_estimation(monkeypatch, target) - collector = EthosUOptimizationPerformance( - test_keras_model, - target, + + context = setup_optimization( [ - [ - {"optimization_type": "pruning", "optimization_target": 0.5}, - ] + {"optimization_type": "pruning", "optimization_target": 0.5}, ], ) - collector.set_context(sample_context) + collector = EthosUOptimizationPerformance(test_keras_model, target) + collector.set_context(context) result = collector.collect_data() assert isinstance(result, OptimizationPerformanceMetrics) @@ -105,34 +118,39 @@ def test_optimization_performance_collector( assert opt == [OptimizationSettings("pruning", 0.5, None)] assert isinstance(metrics, PerformanceMetrics) - collector_no_optimizations = EthosUOptimizationPerformance( - test_keras_model, - target, - [], + context = ExecutionContext( + config_parameters={"common_optimizations": {"optimizations": [[]]}} ) + + collector_no_optimizations = EthosUOptimizationPerformance(test_keras_model, target) + collector_no_optimizations.set_context(context) with pytest.raises(FunctionalityNotSupportedError): collector_no_optimizations.collect_data() - collector_tflite = EthosUOptimizationPerformance( - test_tflite_model, - target, + context = setup_optimization( [ - [ - {"optimization_type": "pruning", "optimization_target": 0.5}, - ] + {"optimization_type": "pruning", "optimization_target": 0.5}, ], ) - collector_tflite.set_context(sample_context) + + collector_tflite = EthosUOptimizationPerformance(test_tflite_model, target) + collector_tflite.set_context(context) with pytest.raises(FunctionalityNotSupportedError): collector_tflite.collect_data() with pytest.raises( Exception, match="Optimization parameters expected to be a list" ): - collector_bad_config = EthosUOptimizationPerformance( - test_keras_model, target, {"optimization_type": "pruning"} # type: ignore + context = ExecutionContext( + config_parameters={ + "common_optimizations": { + "optimizations": [{"optimization_type": "pruning"}] + } + } ) - collector.set_context(sample_context) + + collector_bad_config = EthosUOptimizationPerformance(test_keras_model, target) + collector_bad_config.set_context(context) collector_bad_config.collect_data() |