From 5f063ae1cfbfa2568d2858af0a0ccaf192bb1e8d Mon Sep 17 00:00:00 2001 From: Madeleine Dunn Date: Tue, 26 Mar 2024 13:20:42 +0000 Subject: fix: Update rewrite target name - Rename "fully_connected" to "fully-connected" - This will resolve issues with upstreaming rewrite library changes Signed-off-by: Madeleine Dunn Change-Id: I2f24ae4917a556fd0bd44f0db6ee4e0f7a68cd24 --- README.md | 4 ++-- src/mlia/nn/rewrite/core/rewrite.py | 4 ++-- tests/test_cli_commands.py | 6 +++--- tests/test_nn_rewrite_core_rewrite.py | 14 +++++++------- tests/test_nn_select.py | 10 +++++----- tests/test_target_ethos_u_advisor.py | 4 ++-- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index e24dded..7d08a16 100644 --- a/README.md +++ b/README.md @@ -204,7 +204,7 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \ --target-profile ethos-u55-256 \ --rewrite \ --dataset input.tfrec \ - --rewrite-target fully_connected \ + --rewrite-target fully-connected \ --rewrite-start MobileNet/avg_pool/AvgPool \ --rewrite-end MobileNet/fc1/BiasAdd ``` @@ -226,7 +226,7 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \ --optimization-profile optimization \ --rewrite \ --dataset input.tfrec \ - --rewrite-target fully_connected \ + --rewrite-target fully-connected \ --rewrite-start MobileNet/avg_pool/AvgPool \ --rewrite-end MobileNet/fc1/BiasAdd_ ``` diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index fdfd35c..8658991 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Contains class RewritingOptimizer to replace a subgraph/layer of a model.""" from __future__ import annotations @@ -113,7 +113,7 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ DynamicallyLoadedRewrite( - "fully_connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model" + "fully-connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model" ) ] ) diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 1a9bbb8..9cda27c 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -77,7 +77,7 @@ def test_performance_unknown_target( None, None, True, - "fully_connected", + "fully-connected", "sequential/flatten/Reshape", "StatefulPartitionedCall:0", does_not_raise(), @@ -90,7 +90,7 @@ def test_performance_unknown_target( 0.5, None, True, - "fully_connected", + "fully-connected", "sequential/flatten/Reshape", "StatefulPartitionedCall:0", pytest.raises( @@ -126,7 +126,7 @@ def test_performance_unknown_target( Exception, match=re.escape( "Invalid rewrite target: 'random'. " - "Supported rewrites: ['fully_connected']" + "Supported rewrites: ['fully-connected']" ), ), ], diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 363d614..b32fafd 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.core.rewrite.""" from __future__ import annotations @@ -41,14 +41,14 @@ def test_rewrite() -> None: @pytest.mark.parametrize( "rewrite_name, expected_error", [ - ("fully_connected", does_not_raise()), + ("fully-connected", does_not_raise()), ("random", does_not_raise()), ], ) def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: - """Test get_rewrite function only supports rewrite type fully_connected.""" + """Test get_rewrite function only supports rewrite type fully-connected.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -67,9 +67,9 @@ def test_rewriting_optimizer( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, ) -> None: - """Test fc_layer rewrite process with rewrite type fully_connected.""" + """Test fc_layer rewrite process with rewrite type fully-connected.""" config_obj = RewriteConfiguration( - "fully_connected", + "fully-connected", ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], test_tfrecord_fp32, train_params=MockTrainingParameters(), @@ -100,7 +100,7 @@ def test_register_rewrite_function() -> None: def test_builtin_rewrite_names() -> None: """Test if all builtin rewrites are properly registered and returned.""" - assert RewritingOptimizer.builtin_rewrite_names() == ["fully_connected"] + assert RewritingOptimizer.builtin_rewrite_names() == ["fully-connected"] def test_rewrite_function_autoload() -> None: @@ -146,7 +146,7 @@ def test_rewrite_configuration_train_params( ) config_obj = RewriteConfiguration( - "fully_connected", + "fully-connected", ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], test_tfrecord_fp32, train_params=train_params, diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index 92b7a3d..15abf2d 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -143,19 +143,19 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration ( OptimizationSettings( optimization_type="rewrite", - optimization_target="fully_connected", # type: ignore + optimization_target="fully-connected", # type: ignore layers_to_optimize=None, dataset=None, ), does_not_raise(), RewritingOptimizer, - "rewrite: fully_connected", + "rewrite: fully-connected", ), ( - RewriteConfiguration("fully_connected"), + RewriteConfiguration("fully-connected"), does_not_raise(), RewritingOptimizer, - "rewrite: fully_connected", + "rewrite: fully-connected", ), ], ) @@ -192,7 +192,7 @@ def test_get_optimizer_training_parameters( """Test function get_optimzer with various combinations of parameters.""" config = OptimizationSettings( optimization_type="rewrite", - optimization_target="fully_connected", # type: ignore + optimization_target="fully-connected", # type: ignore layers_to_optimize=None, dataset=None, ) diff --git a/tests/test_target_ethos_u_advisor.py b/tests/test_target_ethos_u_advisor.py index 20131d2..c5e619b 100644 --- a/tests/test_target_ethos_u_advisor.py +++ b/tests/test_target_ethos_u_advisor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Ethos-U MLIA module.""" from __future__ import annotations @@ -40,7 +40,7 @@ def test_advisor_metadata() -> None: [ { "optimization_type": "rewrite", - "optimization_target": "fully_connected", + "optimization_target": "fully-connected", "layers_to_optimize": [ "MobileNet/avg_pool/AvgPool", "MobileNet/fc1/BiasAdd", -- cgit v1.2.1