diff options
Diffstat (limited to 'ethosu/vela/pass_packing.py')
-rw-r--r-- | ethosu/vela/pass_packing.py | 17 |
1 files changed, 17 insertions, 0 deletions
def broadcast_input_check(ps):
    """Swap the two inputs of a pass when the first one is the broadcast operand.

    The caller invokes this for passes whose primary op is a binary
    elementwise op, immediately before ``ps.ifm_tensor`` is taken from
    ``ps.inputs[0]``.  After the call, if input 0 had to be broadcast (size-1
    dimensions) to match input 1, the two entries are exchanged in both
    ``ps.inputs`` and ``ps.primary_op.inputs``.

    Shapes of different rank are compared right-aligned (the shorter one is
    treated as if padded with leading 1s), which also avoids indexing past
    the end of the shorter shape.

    Args:
        ps: Pass object exposing ``inputs`` (sequence of tensors with a
            ``shape`` attribute) and ``primary_op.inputs``.

    Returns:
        None.  ``ps`` and ``ps.primary_op`` are modified in place.
    """
    # Nothing to reorder with fewer than two inputs.
    if len(ps.inputs) < 2:
        return

    first_shape = ps.inputs[0].shape
    second_shape = ps.inputs[1].shape

    # Scalars (empty shape) are left where they are, as are identical shapes.
    if not first_shape or not second_shape or first_shape == second_shape:
        return

    # Right-align both shapes so that dimensions are compared the way
    # broadcasting pairs them up (trailing dimensions first).
    rank = max(len(first_shape), len(second_shape))
    first_dims = [1] * (rank - len(first_shape)) + list(first_shape)
    second_dims = [1] * (rank - len(second_shape)) + list(second_shape)

    # Shapes that differ only by leading 1s need no reordering.
    if first_dims == second_dims:
        return

    # Input 0 is the broadcast operand only if, wherever the shapes differ,
    # its dimension is 1.  Any other mismatch means no swap.
    if any(d0 != d1 and d0 != 1 for d0, d1 in zip(first_dims, second_dims)):
        return

    # NOTE(review): the swap is applied regardless of op type; for
    # non-commutative binary ops the operand order presumably has to be
    # accounted for elsewhere - confirm against the caller.
    ps.inputs[0], ps.inputs[1] = ps.inputs[1], ps.inputs[0]
    op_inputs = ps.primary_op.inputs
    op_inputs[0], op_inputs[1] = op_inputs[1], op_inputs[0]