aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMadeleine Dunn <madeleine.dunn@arm.com>2024-03-26 13:20:42 +0000
committerMadeleine Dunn <madeleine.dunn@arm.com>2024-03-27 10:45:34 +0000
commit5f063ae1cfbfa2568d2858af0a0ccaf192bb1e8d (patch)
tree47ff3c52dfca780a69d2eeba037e1602fe50b655
parentc7ee5b783f044d7ff641773aa385840f5ff944cc (diff)
downloadmlia-5f063ae1cfbfa2568d2858af0a0ccaf192bb1e8d.tar.gz
fix: Update rewrite target name
- Rename "fully_connected" to "fully-connected" - This will resolve issues with upstreaming rewrite library changes Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com> Change-Id: I2f24ae4917a556fd0bd44f0db6ee4e0f7a68cd24
-rw-r--r--README.md4
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py4
-rw-r--r--tests/test_cli_commands.py6
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py14
-rw-r--r--tests/test_nn_select.py10
-rw-r--r--tests/test_target_ethos_u_advisor.py4
6 files changed, 21 insertions, 21 deletions
diff --git a/README.md b/README.md
index e24dded..7d08a16 100644
--- a/README.md
+++ b/README.md
@@ -204,7 +204,7 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \
--target-profile ethos-u55-256 \
--rewrite \
--dataset input.tfrec \
- --rewrite-target fully_connected \
+ --rewrite-target fully-connected \
--rewrite-start MobileNet/avg_pool/AvgPool \
--rewrite-end MobileNet/fc1/BiasAdd
```
@@ -226,7 +226,7 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \
--optimization-profile optimization \
--rewrite \
--dataset input.tfrec \
- --rewrite-target fully_connected \
+ --rewrite-target fully-connected \
--rewrite-start MobileNet/avg_pool/AvgPool \
--rewrite-end MobileNet/fc1/BiasAdd_
```
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index fdfd35c..8658991 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contains class RewritingOptimizer to replace a subgraph/layer of a model."""
from __future__ import annotations
@@ -113,7 +113,7 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
DynamicallyLoadedRewrite(
- "fully_connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model"
+ "fully-connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model"
)
]
)
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 1a9bbb8..9cda27c 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -77,7 +77,7 @@ def test_performance_unknown_target(
None,
None,
True,
- "fully_connected",
+ "fully-connected",
"sequential/flatten/Reshape",
"StatefulPartitionedCall:0",
does_not_raise(),
@@ -90,7 +90,7 @@ def test_performance_unknown_target(
0.5,
None,
True,
- "fully_connected",
+ "fully-connected",
"sequential/flatten/Reshape",
"StatefulPartitionedCall:0",
pytest.raises(
@@ -126,7 +126,7 @@ def test_performance_unknown_target(
Exception,
match=re.escape(
"Invalid rewrite target: 'random'. "
- "Supported rewrites: ['fully_connected']"
+ "Supported rewrites: ['fully-connected']"
),
),
],
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 363d614..b32fafd 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.rewrite."""
from __future__ import annotations
@@ -41,14 +41,14 @@ def test_rewrite() -> None:
@pytest.mark.parametrize(
"rewrite_name, expected_error",
[
- ("fully_connected", does_not_raise()),
+ ("fully-connected", does_not_raise()),
("random", does_not_raise()),
],
)
def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
- """Test get_rewrite function only supports rewrite type fully_connected."""
+ """Test get_rewrite function only supports rewrite type fully-connected."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -67,9 +67,9 @@ def test_rewriting_optimizer(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
) -> None:
- """Test fc_layer rewrite process with rewrite type fully_connected."""
+ """Test fc_layer rewrite process with rewrite type fully-connected."""
config_obj = RewriteConfiguration(
- "fully_connected",
+ "fully-connected",
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
train_params=MockTrainingParameters(),
@@ -100,7 +100,7 @@ def test_register_rewrite_function() -> None:
def test_builtin_rewrite_names() -> None:
"""Test if all builtin rewrites are properly registered and returned."""
- assert RewritingOptimizer.builtin_rewrite_names() == ["fully_connected"]
+ assert RewritingOptimizer.builtin_rewrite_names() == ["fully-connected"]
def test_rewrite_function_autoload() -> None:
@@ -146,7 +146,7 @@ def test_rewrite_configuration_train_params(
)
config_obj = RewriteConfiguration(
- "fully_connected",
+ "fully-connected",
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
train_params=train_params,
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index 92b7a3d..15abf2d 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -143,19 +143,19 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
(
OptimizationSettings(
optimization_type="rewrite",
- optimization_target="fully_connected", # type: ignore
+ optimization_target="fully-connected", # type: ignore
layers_to_optimize=None,
dataset=None,
),
does_not_raise(),
RewritingOptimizer,
- "rewrite: fully_connected",
+ "rewrite: fully-connected",
),
(
- RewriteConfiguration("fully_connected"),
+ RewriteConfiguration("fully-connected"),
does_not_raise(),
RewritingOptimizer,
- "rewrite: fully_connected",
+ "rewrite: fully-connected",
),
],
)
@@ -192,7 +192,7 @@ def test_get_optimizer_training_parameters(
"""Test function get_optimzer with various combinations of parameters."""
config = OptimizationSettings(
optimization_type="rewrite",
- optimization_target="fully_connected", # type: ignore
+ optimization_target="fully-connected", # type: ignore
layers_to_optimize=None,
dataset=None,
)
diff --git a/tests/test_target_ethos_u_advisor.py b/tests/test_target_ethos_u_advisor.py
index 20131d2..c5e619b 100644
--- a/tests/test_target_ethos_u_advisor.py
+++ b/tests/test_target_ethos_u_advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U MLIA module."""
from __future__ import annotations
@@ -40,7 +40,7 @@ def test_advisor_metadata() -> None:
[
{
"optimization_type": "rewrite",
- "optimization_target": "fully_connected",
+ "optimization_target": "fully-connected",
"layers_to_optimize": [
"MobileNet/avg_pool/AvgPool",
"MobileNet/fc1/BiasAdd",