diff options
author | Ruomei Yan <ruomei.yan@arm.com> | 2023-04-20 09:51:20 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:44:51 +0100 |
commit | f3e6597dd50ec70f043d692b773f2d9fd31519ae (patch) | |
tree | 322ccb75e0cc594c57308288cae333a72401979e /src/mlia/nn/rewrite/library | |
parent | 867f37d643e66c0223457c28f5345f2f21db97f2 (diff) | |
download | mlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz |
Implement first rewrite (proof of concept)
* Define replacement function fully_connected layer
* Define RewriteConfiguration and Rewriter to integrate
rewrite module into mlia optimize command
* Fix a bug in the ethos_u/data_collection.py file
* Fix a bug in join.py
* Remove diff_stats and use diff instead, added related
changes around this to ensure e2e tests passing
* Add unit tests for all changes
* Fix bug in diff_stats function
* The bug was caused by a dividing by numpy array
of all zeros. The previous way of handling it
did not consider the all zeros case but only
dealt with partially zeros
* unit tests added.
* Fix the bug in rewrite/core/graph_edit/join.py
* Remove the possibility of passing None to append_relabel
function because it is immutable
* The bug happened when empty dictionary was passed in the
append_relabel function and the function overwrites the
reference of operator_map which caused the dictionary
was not updated after the function call
Resolves: MLIA-749, MLIA-864, MLIA-866
Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/nn/rewrite/library')
-rw-r--r-- | src/mlia/nn/rewrite/library/__init__.py | 3 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/fc_layer.py | 18 |
2 files changed, 21 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/library/__init__.py b/src/mlia/nn/rewrite/library/__init__.py new file mode 100644 index 0000000..2988554 --- /dev/null +++ b/src/mlia/nn/rewrite/library/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions as library.""" diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py new file mode 100644 index 0000000..8704154 --- /dev/null +++ b/src/mlia/nn/rewrite/library/fc_layer.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Example rewrite with one fully connected layer.""" +from typing import Any + +import tensorflow as tf + + +def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model: + """Generate tflite model for rewrite.""" + input_tensor = tf.keras.layers.Input( + shape=input_shape, name="MbileNet/avg_pool/AvgPool" + ) + output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")( + input_tensor + ) + model = tf.keras.Model(input_tensor, output_tensor) + return model |