def broadcast_input_check(ps):
    """Swap the first two inputs of a pass if only the first one is broadcast.

    The swap is applied both to ``ps.inputs`` and to ``ps.primary_op.inputs``
    so the two lists stay in the same order.  Nothing is changed when:

    * the pass has fewer than two inputs,
    * both shapes are identical (also after rank alignment),
    * either shape is empty (scalar), or
    * there is a dimension where the shapes differ and the first input's
      dimension is not 1, i.e. the first input is not the operand that
      would have to be broadcast in that dimension.

    Shapes of different rank are aligned on their trailing dimensions and the
    shorter one is treated as if it were left-padded with 1s (NumPy-style
    broadcasting).  This avoids indexing past the end of the shorter shape
    and avoids comparing unrelated dimensions.

    Args:
        ps: pass object exposing ``inputs`` (a list of tensors, each with a
            ``shape`` list) and ``primary_op.inputs``.
            NOTE(review): assumes ``primary_op.inputs[0:2]`` are the same
            tensors, in the same order, as ``ps.inputs[0:2]`` - confirm
            against the caller.

    Returns:
        None.  ``ps`` is modified in place.
    """
    # Fewer than two inputs: there is nothing to swap.
    if len(ps.inputs) < 2:
        return

    ifm_shape = ps.inputs[0].shape
    ifm2_shape = ps.inputs[1].shape
    if ifm_shape == ifm2_shape:
        return

    # Scalar (empty shape) operands are left untouched.
    if not ifm_shape or not ifm2_shape:
        return

    # Align on trailing dimensions by left-padding the shorter shape with 1s.
    rank = max(len(ifm_shape), len(ifm2_shape))
    padded_ifm = [1] * (rank - len(ifm_shape)) + list(ifm_shape)
    padded_ifm2 = [1] * (rank - len(ifm2_shape)) + list(ifm2_shape)
    if padded_ifm == padded_ifm2:
        return

    # Keep the order unless every differing dimension of the first input is 1.
    if any(d0 != d1 and d0 != 1 for d0, d1 in zip(padded_ifm, padded_ifm2)):
        return

    ps.inputs[0], ps.inputs[1] = ps.inputs[1], ps.inputs[0]
    ps.primary_op.inputs[0], ps.primary_op.inputs[1] = ps.primary_op.inputs[1], ps.primary_op.inputs[0]