author    Gergely Nagy <gergely.nagy@arm.com>    2023-06-22 14:35:21 +0100
committer Benjamin Klimczak <benjamin.klimczak@arm.com>    2023-10-11 16:16:11 +0100
commit    baaf4de286762c1955c874f78cd802d4703a8ba5 (patch)
tree      3b80f906672f91e7e24723720b2d164d360f3edf /tests
parent    3cd84481fa25e64c29e57396d4bf32d7a3ca490a (diff)
download  mlia-baaf4de286762c1955c874f78cd802d4703a8ba5.tar.gz
Re-factoring of rewrite management & added metrics
- List available rewrites
- Refactor/rename 'Rewrite' class to 'RewritingOptimizer'
- Introduce a registry for rewrite functions
- Refactor 'Rewriter' to use the registry to look up rewrite functions
- Remove mentions of hardcoded "fully_connected" from CLI help and error messages, using the registry instead
- Add unit tests
- Enable rewrites for all targets: extract optimization (including rewrite-specific code) from the Ethos-U-specific data collector into OptimizingDataCollector. This is reused in other targets' collectors, such as TOSA and Cortex-A.
- Add more logging for rewrite
- Add display of MAE and NRMSE values for the trained result
- Add total model MAE and NRMSE metric

Resolves: MLIA-891, MLIA-899, MLIA-906

Change-Id: Ie798749e1ed60cab14fdb6d9c2271c833960e93f
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
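For orientation, below is a minimal sketch of the registry pattern the tests in this diff exercise. The class and method names (Rewrite, RewriteRegistry, register_rewrite, names) come from the test code itself; the bodies are reconstructed from the test assertions and are not the actual MLIA implementation.

from __future__ import annotations

from typing import Any, Callable

RewriteCallable = Callable[..., Any]


class Rewrite:
    """A named rewrite wrapping a callable; failures surface as RuntimeError."""

    def __init__(self, name: str, rewrite_fn: RewriteCallable) -> None:
        self.name = name
        self.function = rewrite_fn

    def __call__(self, *args: Any) -> Any:
        try:
            return self.function(*args)
        except Exception as ex:
            # test_rewrite() expects a RuntimeError when the wrapped
            # function raises.
            raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex


class RewriteRegistry:
    """Keeps rewrites keyed by name, preserving registration order."""

    def __init__(self) -> None:
        self.items: dict[str, Rewrite] = {}

    def register_rewrite(self, rewrite: Rewrite) -> bool:
        """Register a rewrite; refuse to overwrite an existing name."""
        if rewrite.name in self.items:
            return False
        self.items[rewrite.name] = rewrite
        return True

    def names(self) -> list[str]:
        """Return registered names in registration order."""
        return list(self.items.keys())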
Diffstat (limited to 'tests')
-rw-r--r--  tests/test_cli_commands.py                     6
-rw-r--r--  tests/test_common_optimization.py             67
-rw-r--r--  tests/test_nn_rewrite_core_rewrite.py         80
-rw-r--r--  tests/test_nn_rewrite_core_train.py            2
-rw-r--r--  tests/test_target_cortex_a_advisor.py         24
-rw-r--r--  tests/test_target_ethos_u_data_collection.py  62
-rw-r--r--  tests/test_target_tosa_advisor.py             22
-rw-r--r--  tests/test_utils_registry.py                   3
8 files changed, 227 insertions, 39 deletions
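The commit message adds MAE and NRMSE reporting for trained rewrites. As a reference for readers, a common way to compute these two metrics is sketched here; the exact normalisation factor MLIA uses is not visible in this diff, so the RMS-of-reference denominator below is an assumption (std- or range-based variants are equally common).

import numpy as np


def mae(reference: np.ndarray, result: np.ndarray) -> float:
    """Mean absolute error between reference and optimized outputs."""
    return float(np.mean(np.abs(reference - result)))


def nrmse(reference: np.ndarray, result: np.ndarray) -> float:
    """Root-mean-square error normalised by the RMS of the reference."""
    rmse = float(np.sqrt(np.mean((reference - result) ** 2)))
    denom = float(np.sqrt(np.mean(reference ** 2)))
    # Guard against an all-zero reference signal.
    return rmse / denom if denom else rmse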
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index e4bbe91..6b1f19d 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -3,6 +3,7 @@
"""Tests for cli.commands module."""
from __future__ import annotations
+import re
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
@@ -116,7 +117,10 @@ def test_performance_unknown_target(
"node_y",
pytest.raises(
Exception,
- match=(r"Currently only remove and fully_connected are supported."),
+ match=re.escape(
+ "Invalid rewrite target: 'random'. "
+ "Supported rewrites: ['fully_connected']"
+ ),
),
],
[
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
new file mode 100644
index 0000000..599610d
--- /dev/null
+++ b/tests/test_common_optimization.py
@@ -0,0 +1,67 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the common optimization module."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.context import ExecutionContext
+from mlia.nn.common import Optimizer
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.config import TargetProfile
+
+
+class FakeOptimizer(Optimizer):
+ """Optimizer for testing purposes."""
+
+ def __init__(self, optimized_model_path: Path) -> None:
+ """Initialize."""
+ super().__init__()
+ self.optimized_model_path = optimized_model_path
+ self.invocation_count = 0
+
+ def apply_optimization(self) -> None:
+ """Count the invocations."""
+ self.invocation_count += 1
+
+ def get_model(self) -> TFLiteModel:
+ """Return optimized model."""
+ return TFLiteModel(self.optimized_model_path)
+
+ def optimization_config(self) -> str:
+ """Return something: doesn't matter, not used."""
+ return ""
+
+
+def test_optimizing_data_collector(
+ test_keras_model: Path,
+ test_tflite_model: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test OptimizingDataCollector, base support for various targets."""
+ optimizations = [
+ [
+ {"optimization_type": "fake", "optimization_target": 42},
+ ]
+ ]
+ context = ExecutionContext(
+ config_parameters={"common_optimizations": {"optimizations": optimizations}}
+ )
+
+ target_profile = MagicMock(spec=TargetProfile)
+
+ fake_optimizer = FakeOptimizer(test_tflite_model)
+
+ monkeypatch.setattr(
+ "mlia.target.common.optimization.get_optimizer",
+ MagicMock(return_value=fake_optimizer),
+ )
+
+ collector = OptimizingDataCollector(test_keras_model, target_profile)
+
+ collector.set_context(context)
+ collector.collect_data()
+
+ assert fake_optimizer.invocation_count == 1
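One detail worth calling out in the test above: 'optimizations' is a list of lists. Reading the call sites across this diff, the outer list appears to hold alternative optimization combinations, and each inner list the techniques applied together in one pass. The default combination injected into the advisor configs below has this shape:

optimizations = [
    [  # one combination, applied together in a single pass
        {"optimization_type": "pruning", "optimization_target": 0.5},
        {"optimization_type": "clustering", "optimization_target": 32},
    ],
]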
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 2542db2..d4aac56 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -6,15 +6,35 @@ from __future__ import annotations
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
+from typing import cast
import pytest
+from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
+from mlia.nn.rewrite.core.rewrite import Rewrite
+from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
-from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.rewrite.core.rewrite import RewriteRegistry
+from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import TestTrainingParameters
+def mock_rewrite_function(*_: Any) -> Any:
+ """Mock function to test autoloading of rewrite functions."""
+
+
+def test_rewrite() -> None:
+ """Test the Rewrite class."""
+
+ def bad_rewrite_func() -> Any:
+ raise NotImplementedError()
+
+ rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ with pytest.raises(RuntimeError):
+ rewrite((1, 2), (1, 2))
+
+
@pytest.mark.parametrize(
"rewrite_name, expected_error",
[
@@ -35,9 +55,9 @@ def test_rewrite_configuration(
assert config_obj.optimization_target in str(config_obj)
- rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
- assert isinstance(rewriter_obj, Rewriter)
+ assert isinstance(rewriter_obj, RewritingOptimizer)
def test_rewriting_optimizer(
@@ -52,8 +72,60 @@ def test_rewriting_optimizer(
train_params=TestTrainingParameters(),
)
- test_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
test_obj.apply_optimization()
trained_model = test_obj.get_model()
assert isinstance(trained_model, TFLiteModel)
+
+ cfg = test_obj.optimization_config()
+ assert isinstance(cfg, str)
+ assert cfg
+
+
+def test_register_rewrite_function() -> None:
+ """Test adding rewrite functions and verify the are reported via the registry."""
+ registry = RewriteRegistry()
+
+ rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+
+ registry.register_rewrite(rewrite1)
+ registry.register_rewrite(rewrite2)
+ assert registry.names() == ["r1", "r2"]
+
+
+def test_builtin_rewrite_names() -> None:
+ """Test if all builtin rewrites are properly registered and returned."""
+ assert RewritingOptimizer.builtin_rewrite_names() == ["fully_connected"]
+
+
+def test_rewrite_function_autoload() -> None:
+ """Test rewrite function loading."""
+ function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
+ rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
+ assert rewrite.name == "mock_rewrite"
+
+ assert rewrite.function is not mock_rewrite_function
+ assert rewrite.load_function(function_name) is mock_rewrite_function
+ assert rewrite.function is mock_rewrite_function
+
+
+def test_rewrite_function_autoload_fail() -> None:
+ """Test rewrite function loading failure."""
+ function_name = "invalid_module.invalid_function"
+ rewrite = DynamicallyLoadedRewrite(
+ name="mock_rewrite",
+ function_name="invalid_module.invalid_function",
+ )
+ assert rewrite.name == "mock_rewrite"
+
+ with pytest.raises(Exception) as exc_info:
+ rewrite.load_function(function_name)
+
+ message = exc_info.value.args[0]
+
+ assert message == (
+ "Unable to load rewrite function 'invalid_module.invalid_function'"
+ " for 'mock_rewrite'."
+ )
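The two autoload tests pin down the observable behaviour of DynamicallyLoadedRewrite: the function is resolved lazily from a dotted path, load_function() caches it on the instance, and failures are wrapped with a fixed message. A sketch consistent with those assertions, building on the Rewrite sketch shown after the commit message (importlib-based resolution is an assumption):

from importlib import import_module


class DynamicallyLoadedRewrite(Rewrite):
    """Rewrite whose callable is resolved lazily from 'module.function'."""

    def __init__(self, name: str, function_name: str) -> None:
        # Placeholder so 'rewrite.function is not mock_rewrite_function'
        # holds before load_function() is called.
        super().__init__(name, lambda *_: None)
        self.function_name = function_name

    def load_function(self, function_name: str) -> RewriteCallable:
        """Import the module, resolve the attribute and cache it."""
        module_name, _, fn_name = function_name.rpartition(".")
        try:
            module = import_module(module_name)
            self.function = getattr(module, fn_name)
            return self.function
        except Exception as ex:
            raise Exception(
                f"Unable to load rewrite function '{function_name}'"
                f" for '{self.name}'."
            ) from ex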
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 4493671..b001a09 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -62,7 +62,7 @@ def check_train(
train_params=train_params,
)
assert len(result) == 2
- assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
+ assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
assert output_file.is_file()
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 9e0082f..6e370d6 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -31,7 +31,23 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
"cortex_a_inference_advisor": {
"model": str(test_tflite_model),
"target_profile": "cortex-a",
- }
+ },
+ "common_optimizations": {
+ "optimizations": [
+ [
+ {
+ "layers_to_optimize": None,
+ "optimization_target": 0.5,
+ "optimization_type": "pruning",
+ },
+ {
+ "layers_to_optimize": None,
+ "optimization_target": 32,
+ "optimization_type": "clustering",
+ },
+ ]
+ ]
+ },
}
assert isinstance(workflow, DefaultWorkflowExecutor)
@@ -43,11 +59,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
[
AdviceCategory.PERFORMANCE,
"Performance estimation is currently not supported for Cortex-A.",
- ],
- [
- AdviceCategory.OPTIMIZATION,
- "Model optimizations are currently not supported for Cortex-A.",
- ],
+ ]
],
)
def test_unsupported_advice_categories(
diff --git a/tests/test_target_ethos_u_data_collection.py b/tests/test_target_ethos_u_data_collection.py
index 6244f8b..be93c26 100644
--- a/tests/test_target_ethos_u_data_collection.py
+++ b/tests/test_target_ethos_u_data_collection.py
@@ -8,9 +8,11 @@ import pytest
from mlia.backend.vela.compat import Operators
from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
from mlia.core.data_collection import DataCollector
from mlia.core.errors import FunctionalityNotSupportedError
from mlia.nn.select import OptimizationSettings
+from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility
from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
@@ -46,6 +48,20 @@ def test_collectors_metadata(
assert collector.name() == expected_name
+def setup_optimization(optimizations: list) -> Context:
+ """Set up optimization params for the context."""
+ params: dict = {}
+ add_common_optimization_params(
+ params,
+ {
+ "optimization_targets": optimizations,
+ },
+ )
+
+ context = ExecutionContext(config_parameters=params)
+ return context
+
+
def test_operator_compatibility_collector(
sample_context: Context, test_tflite_model: Path
) -> None:
@@ -76,7 +92,6 @@ def test_performance_collector(
def test_optimization_performance_collector(
monkeypatch: pytest.MonkeyPatch,
- sample_context: Context,
test_keras_model: Path,
test_tflite_model: Path,
) -> None:
@@ -84,16 +99,14 @@ def test_optimization_performance_collector(
target = EthosUConfiguration.load_profile("ethos-u55-256")
mock_performance_estimation(monkeypatch, target)
- collector = EthosUOptimizationPerformance(
- test_keras_model,
- target,
+
+ context = setup_optimization(
[
- [
- {"optimization_type": "pruning", "optimization_target": 0.5},
- ]
+ {"optimization_type": "pruning", "optimization_target": 0.5},
],
)
- collector.set_context(sample_context)
+ collector = EthosUOptimizationPerformance(test_keras_model, target)
+ collector.set_context(context)
result = collector.collect_data()
assert isinstance(result, OptimizationPerformanceMetrics)
@@ -105,34 +118,39 @@ def test_optimization_performance_collector(
assert opt == [OptimizationSettings("pruning", 0.5, None)]
assert isinstance(metrics, PerformanceMetrics)
- collector_no_optimizations = EthosUOptimizationPerformance(
- test_keras_model,
- target,
- [],
+ context = ExecutionContext(
+ config_parameters={"common_optimizations": {"optimizations": [[]]}}
)
+
+ collector_no_optimizations = EthosUOptimizationPerformance(test_keras_model, target)
+ collector_no_optimizations.set_context(context)
with pytest.raises(FunctionalityNotSupportedError):
collector_no_optimizations.collect_data()
- collector_tflite = EthosUOptimizationPerformance(
- test_tflite_model,
- target,
+ context = setup_optimization(
[
- [
- {"optimization_type": "pruning", "optimization_target": 0.5},
- ]
+ {"optimization_type": "pruning", "optimization_target": 0.5},
],
)
- collector_tflite.set_context(sample_context)
+
+ collector_tflite = EthosUOptimizationPerformance(test_tflite_model, target)
+ collector_tflite.set_context(context)
with pytest.raises(FunctionalityNotSupportedError):
collector_tflite.collect_data()
with pytest.raises(
Exception, match="Optimization parameters expected to be a list"
):
- collector_bad_config = EthosUOptimizationPerformance(
- test_keras_model, target, {"optimization_type": "pruning"} # type: ignore
+ context = ExecutionContext(
+ config_parameters={
+ "common_optimizations": {
+ "optimizations": [{"optimization_type": "pruning"}]
+ }
+ }
)
- collector.set_context(sample_context)
+
+ collector_bad_config = EthosUOptimizationPerformance(test_keras_model, target)
+ collector_bad_config.set_context(context)
collector_bad_config.collect_data()
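Judging from setup_optimization() above and the ExecutionContext fallbacks used alongside it, add_common_optimization_params appears to translate an advisor-level 'optimization_targets' entry into the 'common_optimizations' block the collectors read. A hypothetical reconstruction under that assumption (the real helper likely also injects the pruning/clustering defaults seen in the advisor tests):

def add_common_optimization_params(params: dict, advisor_params: dict) -> None:
    """Expose optimization targets as one combination under the shared
    'common_optimizations' config block (sketch, not the MLIA source)."""
    targets = advisor_params.get("optimization_targets")
    if targets is not None:
        params["common_optimizations"] = {"optimizations": [targets]}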
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index f4f1e36..36e52e9 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -32,10 +32,26 @@ def test_configure_and_get_tosa_advisor(
assert advisor.get_events(ctx) == get_events_mock
assert ctx.event_handlers is not None
assert ctx.config_parameters == {
+ "common_optimizations": {
+ "optimizations": [
+ [
+ {
+ "layers_to_optimize": None,
+ "optimization_target": 0.5,
+ "optimization_type": "pruning",
+ },
+ {
+ "layers_to_optimize": None,
+ "optimization_target": 32,
+ "optimization_type": "clustering",
+ },
+ ]
+ ]
+ },
"tosa_inference_advisor": {
"model": str(test_tflite_model),
"target_profile": "tosa",
- }
+ },
}
assert isinstance(workflow, DefaultWorkflowExecutor)
@@ -48,10 +64,6 @@ def test_configure_and_get_tosa_advisor(
AdviceCategory.PERFORMANCE,
"Performance estimation is currently not supported for TOSA.",
],
- [
- AdviceCategory.OPTIMIZATION,
- "Model optimizations are currently not supported for TOSA.",
- ],
],
)
def test_unsupported_advice_categories(
diff --git a/tests/test_utils_registry.py b/tests/test_utils_registry.py
index 95721fc..288c825 100644
--- a/tests/test_utils_registry.py
+++ b/tests/test_utils_registry.py
@@ -8,7 +8,9 @@ def test_registry() -> None:
"""Test Registry class."""
reg = Registry[str]()
assert not str(reg)
+ assert reg.names() == []
assert reg.register("name", "value")
+ assert reg.names() == ["name"]
assert not reg.register("name", "value")
assert "name" in reg.items
assert reg.items["name"] == "value"
@@ -17,3 +19,4 @@ def test_registry() -> None:
assert len(reg.items) == 2
assert "other_name" in reg.items
assert reg.items["other_name"] == "value_2"
+ assert reg.names() == ["name", "other_name"]
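The assertions in test_registry suggest a small generic name-to-item mapping. A sketch consistent with them, insertion-ordered and refusing overwrites (the real class may carry more):

from typing import Dict, Generic, List, TypeVar

T = TypeVar("T")


class Registry(Generic[T]):
    """Minimal registry: named items, no silent overwrites."""

    def __init__(self) -> None:
        self.items: Dict[str, T] = {}

    def register(self, name: str, item: T) -> bool:
        """Return True if registered, False if the name already exists."""
        if name in self.items:
            return False
        self.items[name] = item
        return True

    def names(self) -> List[str]:
        """Registered names in registration order."""
        return list(self.items.keys())

    def __str__(self) -> str:
        # An empty registry stringifies to "" so 'assert not str(reg)' holds.
        return "\n".join(f"{name}: {item}" for name, item in self.items.items())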