diff options
author | Ruomei Yan <ruomei.yan@arm.com> | 2023-04-20 09:51:20 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:44:51 +0100 |
commit | f3e6597dd50ec70f043d692b773f2d9fd31519ae (patch) | |
tree | 322ccb75e0cc594c57308288cae333a72401979e /tests/test_cli_commands.py | |
parent | 867f37d643e66c0223457c28f5345f2f21db97f2 (diff) | |
download | mlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz |
Implement first rewrite (proof of concept)
* Define replacement function fully_connected layer
* Define RewriteConfiguration and Rewriter to integrate
rewrite module into mlia optimize command
* Fix a bug in the ethos_u/data_collection.py file
* Fix a bug in join.py
* Remove diff_stats and use diff instead, added related
changes around this to ensure e2e tests passing
* Add unit tests for all changes
* Fix bug in diff_stats function
* The bug was caused by a dividing by numpy array
of all zeros. The previous way of handling it
did not consider the all zeros case but only
dealt with partially zeros
* unit tests added.
* Fix the bug in rewrite/core/graph_edit/join.py
* Remove the possibility of passing None to append_relabel
function because it is immutable
* The bug happened when empty dictionary was passed in the
append_relabel function and the function overwrites the
reference of operator_map which caused the dictionary
was not updated after the function call
Resolves: MLIA-749, MLIA-864, MLIA-866
Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'tests/test_cli_commands.py')
-rw-r--r-- | tests/test_cli_commands.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 6765a53..e4bbe91 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -73,8 +73,8 @@ def test_performance_unknown_target( None, True, "fully_connected", - "node_a", - "node_b", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", does_not_raise(), ], [ @@ -85,8 +85,8 @@ def test_performance_unknown_target( None, True, "fully_connected", - "node_a", - "node_b", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", pytest.raises( Exception, match=(r"Only 'rewrite' is supported for TensorFlow Lite files."), @@ -157,7 +157,7 @@ def test_performance_unknown_target( ], ], ) -def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments +def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments target_profile: str, sample_context: ExecutionContext, pruning: bool, @@ -171,12 +171,14 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments expected_error: Any, monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, - test_tflite_model: Path, + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, ) -> None: """Test that command should not fail with valid optimization targets.""" mock_performance_estimation(monkeypatch) - model_type = test_tflite_model if rewrite else test_keras_model + model_type = test_tflite_model_fp32 if rewrite else test_keras_model + data = test_tfrecord_fp32 if rewrite else None with expected_error: optimize( @@ -191,6 +193,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments rewrite_target=rewrite_target, rewrite_start=rewrite_start, rewrite_end=rewrite_end, + dataset=data, ) |