diff options
author | Gergely Nagy <gergely.nagy@arm.com> | 2023-06-22 14:35:21 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 16:16:11 +0100 |
commit | baaf4de286762c1955c874f78cd802d4703a8ba5 (patch) | |
tree | 3b80f906672f91e7e24723720b2d164d360f3edf /tests/test_nn_rewrite_core_rewrite.py | |
parent | 3cd84481fa25e64c29e57396d4bf32d7a3ca490a (diff) | |
download | mlia-baaf4de286762c1955c874f78cd802d4703a8ba5.tar.gz |
Re-factoring of rewrite management & added metrics
- List available rewrites
- Refactor/rename 'Rewrite' class to 'RewritingOptimizer'
- Introduce a registry for rewrite functions
- Refactor 'Rewriter' to use the registry to look up rewrite functions
- Remove mentions of hardcoded "fully_connected" from CLI help and
error messages, using the registry instead
- Add unit tests
- Enable rewrites for all targets:
Extract optimization (including rewrite specific code) from the
Ethos-U-specific data collector into OptimizingDataCollector.
This is reused in other targets' collectors, such as TOSA
and Cortex-A.
- Add more logging for rewrite
- add display of MAE and NRMSE values for the trained result
- add total model MAE and NRMSE metric
Resolves: MLIA-891, MLIA-899, MLIA-906
Change-Id: Ie798749e1ed60cab14fdb6d9c2271c833960e93f
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 80 |
1 files changed, 76 insertions, 4 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 2542db2..d4aac56 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -6,15 +6,35 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any +from typing import cast import pytest +from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite +from mlia.nn.rewrite.core.rewrite import Rewrite +from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration -from mlia.nn.rewrite.core.rewrite import Rewriter +from mlia.nn.rewrite.core.rewrite import RewriteRegistry +from mlia.nn.rewrite.core.rewrite import RewritingOptimizer from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import TestTrainingParameters +def mock_rewrite_function(*_: Any) -> Any: + """Mock function to test autoloading of rewrite functions.""" + + +def test_rewrite() -> None: + """Test the Rewrite class.""" + + def bad_rewrite_func() -> Any: + raise NotImplementedError() + + rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)) + with pytest.raises(RuntimeError): + rewrite((1, 2), (1, 2)) + + @pytest.mark.parametrize( "rewrite_name, expected_error", [ @@ -35,9 +55,9 @@ def test_rewrite_configuration( assert config_obj.optimization_target in str(config_obj) - rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) + rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name - assert isinstance(rewriter_obj, Rewriter) + assert isinstance(rewriter_obj, RewritingOptimizer) def test_rewriting_optimizer( @@ -52,8 +72,60 @@ def test_rewriting_optimizer( train_params=TestTrainingParameters(), ) - test_obj = Rewriter(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) test_obj.apply_optimization() trained_model = test_obj.get_model() assert isinstance(trained_model, TFLiteModel) + + cfg = test_obj.optimization_config() + assert isinstance(cfg, str) + assert cfg + + +def test_register_rewrite_function() -> None: + """Test adding rewrite functions and verify the are reported via the registry.""" + registry = RewriteRegistry() + + rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2)) + + registry.register_rewrite(rewrite1) + registry.register_rewrite(rewrite2) + assert registry.names() == ["r1", "r2"] + + +def test_builtin_rewrite_names() -> None: + """Test if all builtin rewrites are properly registered and returned.""" + assert RewritingOptimizer.builtin_rewrite_names() == ["fully_connected"] + + +def test_rewrite_function_autoload() -> None: + """Test rewrite function loading.""" + function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function" + rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name) + assert rewrite.name == "mock_rewrite" + + assert rewrite.function is not mock_rewrite_function + assert rewrite.load_function(function_name) is mock_rewrite_function + assert rewrite.function is mock_rewrite_function + + +def test_rewrite_function_autoload_fail() -> None: + """Test rewrite function loading failure.""" + function_name = "invalid_module.invalid_function" + rewrite = DynamicallyLoadedRewrite( + name="mock_rewrite", + function_name="invalid_module.invalid_function", + ) + assert rewrite.name == "mock_rewrite" + + with pytest.raises(Exception) as exc_info: + rewrite.load_function(function_name) + + message = exc_info.value.args[0] + + assert message == ( + "Unable to load rewrite function 'invalid_module.invalid_function'" + " for 'mock_rewrite'." + ) |