diff options
author | Johan Alfvén <johan.alfven@arm.com> | 2022-10-19 11:20:12 +0200 |
---|---|---|
committer | Johan Alfvén <johan.alfven@arm.com> | 2022-10-20 14:38:55 +0200 |
commit | 56a71b0108f43a1cb118b1e2fae902c31b2a9969 (patch) | |
tree | 4692139c4dcd6e53b55b1a07ff1b09eb8461a4a8 /ethosu/vela/cascade_builder.py | |
parent | fba0a7dc43373a69f3c0792587d3d9b0cc010ccf (diff) | |
download | ethos-u-vela-56a71b0108f43a1cb118b1e2fae902c31b2a9969.tar.gz |
MLBEDSW-7019: Update to elementwise cascading
- The cascade builder is using the ifm_ifm2_correct_order
function in order to decide if the operator is cascadable or not.
The problem is that this function expects a full shape or no shape
and the cascade builder did not provide that, so the operator was
reported to be non cascadable.
- The fix is to provide a full 4D shape, also refactoring
ifm_ifm2_correct_order to use 4D shape to avoid confusion
in the future.
- Refactoring code so that the scheduler can perform a
correct ifm and ifm2 swap.
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I9a86c4690612f332afa428456a07e67698852495
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r-- | ethosu/vela/cascade_builder.py | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py index 9c84ba8d..b042ba73 100644 --- a/ethosu/vela/cascade_builder.py +++ b/ethosu/vela/cascade_builder.py @@ -98,7 +98,7 @@ class CascadeBuilder: and cost.stripe.height < sched_op.ofm.shape.height and sched_op.parent_op.read_offsets[0] is None and sched_op.parent_op.read_offsets[1] is None - and self.element_wise_cascading_conformity(sched_op) + and self.elementwise_cascading_correct_order(sched_op) and not sched_op.parent_op.type.is_resize_op() and not sched_op.parent_op.type == Op.Conv2DBackpropInputSwitchedBias and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE @@ -127,22 +127,34 @@ class CascadeBuilder: return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0) @staticmethod - def element_wise_cascading_conformity(sched_op): + def elementwise_cascading_conformity(sched_op): """Check the inputs of the op to see if it's a candidate for cascading.""" - ifm = sched_op.parent_op.ifm - ifm2 = sched_op.parent_op.ifm2 - - if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2: + if sched_op.parent_op.type.is_binary_elementwise_op(): # We cannot rule out cascadability if at least one IFM is constant + ifm = sched_op.parent_op.ifm + ifm2 = sched_op.parent_op.ifm2 ifm_const = ifm.ops != [] and ifm.ops[0].type == Op.Const ifm2_const = ifm2.ops != [] and ifm2.ops[0].type == Op.Const - correct_order = ifm_ifm2_correct_order(ifm.shape, ifm2.shape) - return (ifm_const and (ifm.shape == ifm2.shape or not correct_order)) or (ifm2_const and correct_order) + return ifm_const or ifm2_const else: # Either one IFM is not variable or it is not a binary elementwise op - we cannot rule out cascadability return True + @staticmethod + def elementwise_cascading_correct_order(sched_op): + """Check the inputs of the op to see ifm and ifm2 has correct order.""" + + if sched_op.parent_op.type.is_binary_elementwise_op(): + ifm2 = sched_op.parent_op.ifm2 + ifm2_const = ifm2.ops != [] and ifm2.ops[0].type == Op.Const + + # ifm_ifm2_correct_order needs full shape + correct_order = ifm_ifm2_correct_order(sched_op.ifm.shape, sched_op.ifm2.shape) + return ifm2_const and correct_order + else: + return True + def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit): ref_cost = ref_schedule.cost_map fallback_cost = fallback_schedule.cost_map |