aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/pass_packing.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 4cfac33c..fff192d7 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -256,6 +256,20 @@ def pack_into_passes(nng, arch, verbose_packing=False):
ofm_tensor = op.outputs[0]
build_pass((op,), ofm_tensor)
+ def broadcast_input_check(ps):
+ if len(ps.inputs) == 1 or ps.inputs[0].shape == ps.inputs[1].shape:
+ return
+
+ if ps.inputs[0].shape == [] or ps.inputs[1].shape == []:
+ return
+
+ for idx in range(len(ps.inputs[1].shape)):
+ if ps.inputs[1].shape[idx] != ps.inputs[0].shape[idx] and ps.inputs[0].shape[idx] != 1:
+ return
+
+ ps.inputs[0], ps.inputs[1] = ps.inputs[1], ps.inputs[0]
+ ps.primary_op.inputs[0], ps.primary_op.inputs[1] = ps.primary_op.inputs[1], ps.primary_op.inputs[0]
+
def build_pass(start_ops_to_process, ofm_tensor=None):
reverse_ops_list = []
curr_flags = PassFlags.Empty
@@ -400,6 +414,9 @@ def pack_into_passes(nng, arch, verbose_packing=False):
# ElementWise operation, 2 IFMs
if ps.primary_op and ps.primary_op.type in binary_elem_wise_main_ops:
+ # Swap broadcast input if applicable
+ broadcast_input_check(ps)
+
ps.ifm_tensor = ps.inputs[0]
if len(ps.inputs) == 1: