aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/cascade_builder.py
diff options
context:
space:
mode:
authorerik.andersson@arm.com <erik.andersson@arm.com>2022-03-22 15:35:30 +0100
committererik.andersson@arm.com <erik.andersson@arm.com>2022-07-11 11:27:47 +0200
commit6b2a0b4a64d01c8b038050a87c29f38a4909515c (patch)
tree0bd213a78debbfbe8465fcbf1c87eadd1f44fc2f /ethosu/vela/cascade_builder.py
parent25f48dd70aebeecd490de71eed3d4f7fbad1b121 (diff)
downloadethos-u-vela-6b2a0b4a64d01c8b038050a87c29f38a4909515c.tar.gz
MLBEDSW-6261: Elementwise cascading
Enabled elementwise cascading for binary/single variable IFM operators. Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com> Change-Id: I1c0867875fdc5c4980224fb570185c11e719d5cd
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r--ethosu/vela/cascade_builder.py24
1 files changed, 22 insertions, 2 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index e7105e2c..09c36b9e 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -18,12 +18,12 @@
# Groups Operators in a schedule together to form Cascades.
from .numeric_util import round_up
from .operation import NpuBlockType
+from .operation import Op
from .shape4d import Shape4D
non_cascadable_blocks = (
NpuBlockType.Default,
NpuBlockType.VectorProduct,
- NpuBlockType.ElementWise,
NpuBlockType.ReduceSum,
)
@@ -89,11 +89,13 @@ class CascadeBuilder:
def _is_cascadable(self, sched_op, cost) -> bool:
"""Checks if 'sched_op' can be cascaded"""
+
return (
sched_op.op_type.npu_block_type not in non_cascadable_blocks
and cost.stripe.height < sched_op.ofm.shape.height
and sched_op.parent_op.read_offsets[0] is None
and sched_op.parent_op.read_offsets[1] is None
+ and self.element_wise_cascading_conformity(sched_op)
)
def _estimate_sram_usage(self, sched_op, cost) -> int:
@@ -115,6 +117,24 @@ class CascadeBuilder:
return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
+ @staticmethod
+ def element_wise_cascading_conformity(sched_op):
+ """Check the inputs of the op to see if it's a candidate for cascading."""
+ # Cascading sub-operators of Softmax results in a crash when handling Sub and RescaleAdd ops
+
+ ifm = sched_op.parent_op.ifm
+ ifm2 = sched_op.parent_op.ifm2
+
+ if sched_op.op_type in [Op.RescaleAdd]:
+ return False
+
+ if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2:
+ # We cannot rule out cascadability if at least one IFM is constant
+ return Op.Const in (ifm.ops[0], ifm2.ops[0])
+ else:
+ # Either one IFM is not variable or it is not a binary elementwise op - we cannot rule out cascadability
+ return True
+
def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
ref_cost = ref_schedule.cost_map
fallback_cost = fallback_schedule.cost_map
@@ -260,7 +280,7 @@ class CascadeBuilder:
if not self.spilling:
peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
- # Update costing and cascde information for the ref_schedule
+ # Update costing and cascade information for the ref_schedule
ref_schedule.cost_map = cost
ref_schedule.cascades = cascade_map
return ref_schedule