From baaf4de286762c1955c874f78cd802d4703a8ba5 Mon Sep 17 00:00:00 2001 From: Gergely Nagy Date: Thu, 22 Jun 2023 14:35:21 +0100 Subject: Re-factoring of rewrite management & added metrics - List available rewrites - Refactor/rename 'Rewrite' class to 'RewritingOptimizer' - Introduce a registry for rewrite functions - Refactor 'Rewriter' to use the registry to look up rewrite functions - Remove mentions of hardcoded "fully_connected" from CLI help and error messages, using the registry instead - Add unit tests - Enable rewrites for all targets: Extract optimization (including rewrite specific code) from the Ethos-U-specific data collector into OptimizingDataCollector. This is reused in other targets' collectors, such as TOSA and Cortex-A. - Add more logging for rewrite - add display of MAE and NRMSE values for the trained result - add total model MAE and NRMSE metric Resolves: MLIA-891, MLIA-899, MLIA-906 Change-Id: Ie798749e1ed60cab14fdb6d9c2271c833960e93f Signed-off-by: Benjamin Klimczak --- tests/test_cli_commands.py | 6 ++- tests/test_common_optimization.py | 67 +++++++++++++++++++++++ tests/test_nn_rewrite_core_rewrite.py | 80 ++++++++++++++++++++++++++-- tests/test_nn_rewrite_core_train.py | 2 +- tests/test_target_cortex_a_advisor.py | 24 ++++++--- tests/test_target_ethos_u_data_collection.py | 62 +++++++++++++-------- tests/test_target_tosa_advisor.py | 22 ++++++-- tests/test_utils_registry.py | 3 ++ 8 files changed, 227 insertions(+), 39 deletions(-) create mode 100644 tests/test_common_optimization.py (limited to 'tests') diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index e4bbe91..6b1f19d 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -3,6 +3,7 @@ """Tests for cli.commands module.""" from __future__ import annotations +import re from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any @@ -116,7 +117,10 @@ def test_performance_unknown_target( "node_y", pytest.raises( Exception, - match=(r"Currently only remove and fully_connected are supported."), + match=re.escape( + "Invalid rewrite target: 'random'. " + "Supported rewrites: ['fully_connected']" + ), ), ], [ diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py new file mode 100644 index 0000000..599610d --- /dev/null +++ b/tests/test_common_optimization.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the common optimization module.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.core.context import ExecutionContext +from mlia.nn.common import Optimizer +from mlia.nn.tensorflow.config import TFLiteModel +from mlia.target.common.optimization import OptimizingDataCollector +from mlia.target.config import TargetProfile + + +class FakeOptimizer(Optimizer): + """Optimizer for testing purposes.""" + + def __init__(self, optimized_model_path: Path) -> None: + """Initialize.""" + super().__init__() + self.optimized_model_path = optimized_model_path + self.invocation_count = 0 + + def apply_optimization(self) -> None: + """Count the invocations.""" + self.invocation_count += 1 + + def get_model(self) -> TFLiteModel: + """Return optimized model.""" + return TFLiteModel(self.optimized_model_path) + + def optimization_config(self) -> str: + """Return something: doesn't matter, not used.""" + return "" + + +def test_optimizing_data_collector( + test_keras_model: Path, + test_tflite_model: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test OptimizingDataCollector, base support for various targets.""" + optimizations = [ + [ + {"optimization_type": "fake", "optimization_target": 42}, + ] + ] + context = ExecutionContext( + config_parameters={"common_optimizations": {"optimizations": optimizations}} + ) + + target_profile = MagicMock(spec=TargetProfile) + + fake_optimizer = FakeOptimizer(test_tflite_model) + + monkeypatch.setattr( + "mlia.target.common.optimization.get_optimizer", + MagicMock(return_value=fake_optimizer), + ) + + collector = OptimizingDataCollector(test_keras_model, target_profile) + + collector.set_context(context) + collector.collect_data() + + assert fake_optimizer.invocation_count == 1 diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 2542db2..d4aac56 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -6,15 +6,35 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any +from typing import cast import pytest +from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite +from mlia.nn.rewrite.core.rewrite import Rewrite +from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration -from mlia.nn.rewrite.core.rewrite import Rewriter +from mlia.nn.rewrite.core.rewrite import RewriteRegistry +from mlia.nn.rewrite.core.rewrite import RewritingOptimizer from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import TestTrainingParameters +def mock_rewrite_function(*_: Any) -> Any: + """Mock function to test autoloading of rewrite functions.""" + + +def test_rewrite() -> None: + """Test the Rewrite class.""" + + def bad_rewrite_func() -> Any: + raise NotImplementedError() + + rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)) + with pytest.raises(RuntimeError): + rewrite((1, 2), (1, 2)) + + @pytest.mark.parametrize( "rewrite_name, expected_error", [ @@ -35,9 +55,9 @@ def test_rewrite_configuration( assert config_obj.optimization_target in str(config_obj) - rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) + rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name - assert isinstance(rewriter_obj, Rewriter) + assert isinstance(rewriter_obj, RewritingOptimizer) def test_rewriting_optimizer( @@ -52,8 +72,60 @@ def test_rewriting_optimizer( train_params=TestTrainingParameters(), ) - test_obj = Rewriter(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) test_obj.apply_optimization() trained_model = test_obj.get_model() assert isinstance(trained_model, TFLiteModel) + + cfg = test_obj.optimization_config() + assert isinstance(cfg, str) + assert cfg + + +def test_register_rewrite_function() -> None: + """Test adding rewrite functions and verify the are reported via the registry.""" + registry = RewriteRegistry() + + rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2)) + + registry.register_rewrite(rewrite1) + registry.register_rewrite(rewrite2) + assert registry.names() == ["r1", "r2"] + + +def test_builtin_rewrite_names() -> None: + """Test if all builtin rewrites are properly registered and returned.""" + assert RewritingOptimizer.builtin_rewrite_names() == ["fully_connected"] + + +def test_rewrite_function_autoload() -> None: + """Test rewrite function loading.""" + function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function" + rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name) + assert rewrite.name == "mock_rewrite" + + assert rewrite.function is not mock_rewrite_function + assert rewrite.load_function(function_name) is mock_rewrite_function + assert rewrite.function is mock_rewrite_function + + +def test_rewrite_function_autoload_fail() -> None: + """Test rewrite function loading failure.""" + function_name = "invalid_module.invalid_function" + rewrite = DynamicallyLoadedRewrite( + name="mock_rewrite", + function_name="invalid_module.invalid_function", + ) + assert rewrite.name == "mock_rewrite" + + with pytest.raises(Exception) as exc_info: + rewrite.load_function(function_name) + + message = exc_info.value.args[0] + + assert message == ( + "Unable to load rewrite function 'invalid_module.invalid_function'" + " for 'mock_rewrite'." + ) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 4493671..b001a09 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -62,7 +62,7 @@ def check_train( train_params=train_params, ) assert len(result) == 2 - assert all(res >= 0.0 for res in result), f"Results out of bound: {result}" + assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" assert output_file.is_file() diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py index 9e0082f..6e370d6 100644 --- a/tests/test_target_cortex_a_advisor.py +++ b/tests/test_target_cortex_a_advisor.py @@ -31,7 +31,23 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: "cortex_a_inference_advisor": { "model": str(test_tflite_model), "target_profile": "cortex-a", - } + }, + "common_optimizations": { + "optimizations": [ + [ + { + "layers_to_optimize": None, + "optimization_target": 0.5, + "optimization_type": "pruning", + }, + { + "layers_to_optimize": None, + "optimization_target": 32, + "optimization_type": "clustering", + }, + ] + ] + }, } assert isinstance(workflow, DefaultWorkflowExecutor) @@ -43,11 +59,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: [ AdviceCategory.PERFORMANCE, "Performance estimation is currently not supported for Cortex-A.", - ], - [ - AdviceCategory.OPTIMIZATION, - "Model optimizations are currently not supported for Cortex-A.", - ], + ] ], ) def test_unsupported_advice_categories( diff --git a/tests/test_target_ethos_u_data_collection.py b/tests/test_target_ethos_u_data_collection.py index 6244f8b..be93c26 100644 --- a/tests/test_target_ethos_u_data_collection.py +++ b/tests/test_target_ethos_u_data_collection.py @@ -8,9 +8,11 @@ import pytest from mlia.backend.vela.compat import Operators from mlia.core.context import Context +from mlia.core.context import ExecutionContext from mlia.core.data_collection import DataCollector from mlia.core.errors import FunctionalityNotSupportedError from mlia.nn.select import OptimizationSettings +from mlia.target.common.optimization import add_common_optimization_params from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance @@ -46,6 +48,20 @@ def test_collectors_metadata( assert collector.name() == expected_name +def setup_optimization(optimizations: list) -> Context: + """Set up optimization params for the context.""" + params: dict = {} + add_common_optimization_params( + params, + { + "optimization_targets": optimizations, + }, + ) + + context = ExecutionContext(config_parameters=params) + return context + + def test_operator_compatibility_collector( sample_context: Context, test_tflite_model: Path ) -> None: @@ -76,7 +92,6 @@ def test_performance_collector( def test_optimization_performance_collector( monkeypatch: pytest.MonkeyPatch, - sample_context: Context, test_keras_model: Path, test_tflite_model: Path, ) -> None: @@ -84,16 +99,14 @@ def test_optimization_performance_collector( target = EthosUConfiguration.load_profile("ethos-u55-256") mock_performance_estimation(monkeypatch, target) - collector = EthosUOptimizationPerformance( - test_keras_model, - target, + + context = setup_optimization( [ - [ - {"optimization_type": "pruning", "optimization_target": 0.5}, - ] + {"optimization_type": "pruning", "optimization_target": 0.5}, ], ) - collector.set_context(sample_context) + collector = EthosUOptimizationPerformance(test_keras_model, target) + collector.set_context(context) result = collector.collect_data() assert isinstance(result, OptimizationPerformanceMetrics) @@ -105,34 +118,39 @@ def test_optimization_performance_collector( assert opt == [OptimizationSettings("pruning", 0.5, None)] assert isinstance(metrics, PerformanceMetrics) - collector_no_optimizations = EthosUOptimizationPerformance( - test_keras_model, - target, - [], + context = ExecutionContext( + config_parameters={"common_optimizations": {"optimizations": [[]]}} ) + + collector_no_optimizations = EthosUOptimizationPerformance(test_keras_model, target) + collector_no_optimizations.set_context(context) with pytest.raises(FunctionalityNotSupportedError): collector_no_optimizations.collect_data() - collector_tflite = EthosUOptimizationPerformance( - test_tflite_model, - target, + context = setup_optimization( [ - [ - {"optimization_type": "pruning", "optimization_target": 0.5}, - ] + {"optimization_type": "pruning", "optimization_target": 0.5}, ], ) - collector_tflite.set_context(sample_context) + + collector_tflite = EthosUOptimizationPerformance(test_tflite_model, target) + collector_tflite.set_context(context) with pytest.raises(FunctionalityNotSupportedError): collector_tflite.collect_data() with pytest.raises( Exception, match="Optimization parameters expected to be a list" ): - collector_bad_config = EthosUOptimizationPerformance( - test_keras_model, target, {"optimization_type": "pruning"} # type: ignore + context = ExecutionContext( + config_parameters={ + "common_optimizations": { + "optimizations": [{"optimization_type": "pruning"}] + } + } ) - collector.set_context(sample_context) + + collector_bad_config = EthosUOptimizationPerformance(test_keras_model, target) + collector_bad_config.set_context(context) collector_bad_config.collect_data() diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index f4f1e36..36e52e9 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -32,10 +32,26 @@ def test_configure_and_get_tosa_advisor( assert advisor.get_events(ctx) == get_events_mock assert ctx.event_handlers is not None assert ctx.config_parameters == { + "common_optimizations": { + "optimizations": [ + [ + { + "layers_to_optimize": None, + "optimization_target": 0.5, + "optimization_type": "pruning", + }, + { + "layers_to_optimize": None, + "optimization_target": 32, + "optimization_type": "clustering", + }, + ] + ] + }, "tosa_inference_advisor": { "model": str(test_tflite_model), "target_profile": "tosa", - } + }, } assert isinstance(workflow, DefaultWorkflowExecutor) @@ -48,10 +64,6 @@ def test_configure_and_get_tosa_advisor( AdviceCategory.PERFORMANCE, "Performance estimation is currently not supported for TOSA.", ], - [ - AdviceCategory.OPTIMIZATION, - "Model optimizations are currently not supported for TOSA.", - ], ], ) def test_unsupported_advice_categories( diff --git a/tests/test_utils_registry.py b/tests/test_utils_registry.py index 95721fc..288c825 100644 --- a/tests/test_utils_registry.py +++ b/tests/test_utils_registry.py @@ -8,7 +8,9 @@ def test_registry() -> None: """Test Registry class.""" reg = Registry[str]() assert not str(reg) + assert reg.names() == [] assert reg.register("name", "value") + assert reg.names() == ["name"] assert not reg.register("name", "value") assert "name" in reg.items assert reg.items["name"] == "value" @@ -17,3 +19,4 @@ def test_registry() -> None: assert len(reg.items) == 2 assert "other_name" in reg.items assert reg.items["other_name"] == "value_2" + assert reg.names() == ["name", "other_name"] -- cgit v1.2.1