aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Xu <charles.xu@arm.com>2020-06-17 12:42:41 +0200
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commit70cc12126b664c490ffb02f0cedf34ab539100cc (patch)
tree117930d53a2b9b72babe5fdda760921cc7a51a09
parentc30f495dc013a73e371dd8053a0381e4707ab309 (diff)
downloadethos-u-vela-70cc12126b664c490ffb02f0cedf34ab539100cc.tar.gz
MLBEDSW-2506: Swap broadcast input if applicable
Signed-off-by: Charles Xu <charles.xu@arm.com> Change-Id: I6e8a97486aa2e1a21101f7cc32cd3024a376162a
-rw-r--r--ethosu/vela/pass_packing.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 4cfac33c..fff192d7 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -256,6 +256,20 @@ def pack_into_passes(nng, arch, verbose_packing=False):
ofm_tensor = op.outputs[0]
build_pass((op,), ofm_tensor)
+ def broadcast_input_check(ps):
+ if len(ps.inputs) == 1 or ps.inputs[0].shape == ps.inputs[1].shape:
+ return
+
+ if ps.inputs[0].shape == [] or ps.inputs[1].shape == []:
+ return
+
+ for idx in range(len(ps.inputs[1].shape)):
+ if ps.inputs[1].shape[idx] != ps.inputs[0].shape[idx] and ps.inputs[0].shape[idx] != 1:
+ return
+
+ ps.inputs[0], ps.inputs[1] = ps.inputs[1], ps.inputs[0]
+ ps.primary_op.inputs[0], ps.primary_op.inputs[1] = ps.primary_op.inputs[1], ps.primary_op.inputs[0]
+
def build_pass(start_ops_to_process, ofm_tensor=None):
reverse_ops_list = []
curr_flags = PassFlags.Empty
@@ -400,6 +414,9 @@ def pack_into_passes(nng, arch, verbose_packing=False):
# ElementWise operation, 2 IFMs
if ps.primary_op and ps.primary_op.type in binary_elem_wise_main_ops:
+ # Swap broadcast input if applicable
+ broadcast_input_check(ps)
+
ps.ifm_tensor = ps.inputs[0]
if len(ps.inputs) == 1: