diff options
author | Charles Xu <charles.xu@arm.com> | 2020-06-17 12:42:41 +0200 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-06-18 17:53:52 +0100 |
commit | 70cc12126b664c490ffb02f0cedf34ab539100cc (patch) | |
tree | 117930d53a2b9b72babe5fdda760921cc7a51a09 /ethosu/vela | |
parent | c30f495dc013a73e371dd8053a0381e4707ab309 (diff) | |
download | ethos-u-vela-70cc12126b664c490ffb02f0cedf34ab539100cc.tar.gz |
MLBEDSW-2506: Swap broadcast input if applicable
Signed-off-by: Charles Xu <charles.xu@arm.com>
Change-Id: I6e8a97486aa2e1a21101f7cc32cd3024a376162a
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/pass_packing.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index 4cfac33c..fff192d7 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -256,6 +256,20 @@ def pack_into_passes(nng, arch, verbose_packing=False): ofm_tensor = op.outputs[0] build_pass((op,), ofm_tensor) + def broadcast_input_check(ps): + if len(ps.inputs) == 1 or ps.inputs[0].shape == ps.inputs[1].shape: + return + + if ps.inputs[0].shape == [] or ps.inputs[1].shape == []: + return + + for idx in range(len(ps.inputs[1].shape)): + if ps.inputs[1].shape[idx] != ps.inputs[0].shape[idx] and ps.inputs[0].shape[idx] != 1: + return + + ps.inputs[0], ps.inputs[1] = ps.inputs[1], ps.inputs[0] + ps.primary_op.inputs[0], ps.primary_op.inputs[1] = ps.primary_op.inputs[1], ps.primary_op.inputs[0] + def build_pass(start_ops_to_process, ofm_tensor=None): reverse_ops_list = [] curr_flags = PassFlags.Empty @@ -400,6 +414,9 @@ def pack_into_passes(nng, arch, verbose_packing=False): # ElementWise operation, 2 IFMs if ps.primary_op and ps.primary_op.type in binary_elem_wise_main_ops: + # Swap broadcast input if applicable + broadcast_input_check(ps) + ps.ifm_tensor = ps.inputs[0] if len(ps.inputs) == 1: |