aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/ethos_u/data_collection.py
diff options
context:
space:
mode:
authorRuomei Yan <ruomei.yan@arm.com>2023-04-20 09:51:20 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:44:51 +0100
commitf3e6597dd50ec70f043d692b773f2d9fd31519ae (patch)
tree322ccb75e0cc594c57308288cae333a72401979e /src/mlia/target/ethos_u/data_collection.py
parent867f37d643e66c0223457c28f5345f2f21db97f2 (diff)
downloadmlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz
Implement first rewrite (proof of concept)
* Define replacement function fully_connected layer * Define RewriteConfiguration and Rewriter to integrate rewrite module into mlia optimize command * Fix a bug in the ethos_u/data_collection.py file * Fix a bug in join.py * Remove diff_stats and use diff instead, added related changes around this to ensure e2e tests passing * Add unit tests for all changes * Fix bug in diff_stats function * The bug was caused by a dividing by numpy array of all zeros. The previous way of handling it did not consider the all zeros case but only dealt with partially zeros * unit tests added. * Fix the bug in rewrite/core/graph_edit/join.py * Remove the possibility of passing None to append_relabel function because it is immutable * The bug happened when empty dictionary was passed in the append_relabel function and the function overwrites the reference of operator_map which caused the dictionary was not updated after the function call Resolves: MLIA-749, MLIA-864, MLIA-866 Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/target/ethos_u/data_collection.py')
-rw-r--r--src/mlia/target/ethos_u/data_collection.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index 0f3a8d2..ba8b0fe 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -201,7 +201,8 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
OptimizationSettings(
item.get("optimization_type"), # type: ignore
item.get("optimization_target"), # type: ignore
- item.get("layers_to_optimized"),
+ item.get("layers_to_optimize"),
+ item.get("dataset"),
)
for item in opt_configuration
]