diff options
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r-- | ethosu/vela/cascade_builder.py | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py index e7105e2c..09c36b9e 100644 --- a/ethosu/vela/cascade_builder.py +++ b/ethosu/vela/cascade_builder.py @@ -18,12 +18,12 @@ # Groups Operators in a schedule together to form Cascades. from .numeric_util import round_up from .operation import NpuBlockType +from .operation import Op from .shape4d import Shape4D non_cascadable_blocks = ( NpuBlockType.Default, NpuBlockType.VectorProduct, - NpuBlockType.ElementWise, NpuBlockType.ReduceSum, ) @@ -89,11 +89,13 @@ class CascadeBuilder: def _is_cascadable(self, sched_op, cost) -> bool: """Checks if 'sched_op' can be cascaded""" + return ( sched_op.op_type.npu_block_type not in non_cascadable_blocks and cost.stripe.height < sched_op.ofm.shape.height and sched_op.parent_op.read_offsets[0] is None and sched_op.parent_op.read_offsets[1] is None + and self.element_wise_cascading_conformity(sched_op) ) def _estimate_sram_usage(self, sched_op, cost) -> int: @@ -115,6 +117,24 @@ class CascadeBuilder: return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0) + @staticmethod + def element_wise_cascading_conformity(sched_op): + """Check the inputs of the op to see if it's a candidate for cascading.""" + # Cascading sub-operators of Softmax results in a crash when handling Sub and RescaleAdd ops + + ifm = sched_op.parent_op.ifm + ifm2 = sched_op.parent_op.ifm2 + + if sched_op.op_type in [Op.RescaleAdd]: + return False + + if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2: + # We cannot rule out cascadability if at least one IFM is constant + return Op.Const in (ifm.ops[0], ifm2.ops[0]) + else: + # Either one IFM is not variable or it is not a binary elementwise op - we cannot rule out cascadability + return True + def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit): ref_cost = ref_schedule.cost_map fallback_cost = fallback_schedule.cost_map @@ -260,7 +280,7 @@ class CascadeBuilder: if not self.spilling: peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage) - # Update costing and cascde information for the ref_schedule + # Update costing and cascade information for the ref_schedule ref_schedule.cost_map = cost ref_schedule.cascades = cascade_map return ref_schedule |