aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-10-19 11:20:12 +0200
committerJohan Alfvén <johan.alfven@arm.com>2022-10-20 14:38:55 +0200
commit56a71b0108f43a1cb118b1e2fae902c31b2a9969 (patch)
tree4692139c4dcd6e53b55b1a07ff1b09eb8461a4a8
parentfba0a7dc43373a69f3c0792587d3d9b0cc010ccf (diff)
downloadethos-u-vela-56a71b0108f43a1cb118b1e2fae902c31b2a9969.tar.gz
MLBEDSW-7019: Update to elementwise cascading
- The cascade builder is using the ifm_ifm2_correct_order function in order to decide if the operator is cascadable or not. The problem is that this function expects a full shape or no shape and the cascade builder did not provide that, so the operator was reported to be non cascadable. - The fix is to provide a full 4D shape, also refactoring ifm_ifm2_correct_order to use 4D shape to avoid confusion in the future. - Refactoring code so that the scheduler can perform a correct ifm and ifm2 swap. Signed-off-by: Johan Alfven <johan.alfven@arm.com> Change-Id: I9a86c4690612f332afa428456a07e67698852495
-rw-r--r--ethosu/vela/cascade_builder.py28
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py13
-rw-r--r--ethosu/vela/operation_util.py4
-rw-r--r--ethosu/vela/scheduler.py2
4 files changed, 30 insertions, 17 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 9c84ba8d..b042ba73 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -98,7 +98,7 @@ class CascadeBuilder:
and cost.stripe.height < sched_op.ofm.shape.height
and sched_op.parent_op.read_offsets[0] is None
and sched_op.parent_op.read_offsets[1] is None
- and self.element_wise_cascading_conformity(sched_op)
+ and self.elementwise_cascading_correct_order(sched_op)
and not sched_op.parent_op.type.is_resize_op()
and not sched_op.parent_op.type == Op.Conv2DBackpropInputSwitchedBias
and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE
@@ -127,22 +127,34 @@ class CascadeBuilder:
return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
@staticmethod
- def element_wise_cascading_conformity(sched_op):
+ def elementwise_cascading_conformity(sched_op):
"""Check the inputs of the op to see if it's a candidate for cascading."""
- ifm = sched_op.parent_op.ifm
- ifm2 = sched_op.parent_op.ifm2
-
- if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2:
+ if sched_op.parent_op.type.is_binary_elementwise_op():
# We cannot rule out cascadability if at least one IFM is constant
+ ifm = sched_op.parent_op.ifm
+ ifm2 = sched_op.parent_op.ifm2
ifm_const = ifm.ops != [] and ifm.ops[0].type == Op.Const
ifm2_const = ifm2.ops != [] and ifm2.ops[0].type == Op.Const
- correct_order = ifm_ifm2_correct_order(ifm.shape, ifm2.shape)
- return (ifm_const and (ifm.shape == ifm2.shape or not correct_order)) or (ifm2_const and correct_order)
+ return ifm_const or ifm2_const
else:
# Either one IFM is not variable or it is not a binary elementwise op - we cannot rule out cascadability
return True
+ @staticmethod
+ def elementwise_cascading_correct_order(sched_op):
+ """Check the inputs of the op to see ifm and ifm2 has correct order."""
+
+ if sched_op.parent_op.type.is_binary_elementwise_op():
+ ifm2 = sched_op.parent_op.ifm2
+ ifm2_const = ifm2.ops != [] and ifm2.ops[0].type == Op.Const
+
+ # ifm_ifm2_correct_order needs full shape
+ correct_order = ifm_ifm2_correct_order(sched_op.ifm.shape, sched_op.ifm2.shape)
+ return ifm2_const and correct_order
+ else:
+ return True
+
def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
ref_cost = ref_schedule.cost_map
fallback_cost = fallback_schedule.cost_map
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 202917bd..228c76f8 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -112,14 +112,15 @@ resampling_mode_inv_map = {
}
-def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
- if ifm_shape == []:
+def ifm_ifm2_correct_order(ifm_shape: Shape4D, ifm2_shape: Shape4D) -> bool:
+
+ if ifm_shape is None:
# Scalar needs to be in IFM2
return False
- if ifm2_shape == []:
+ if ifm2_shape is None:
return True
- for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
+ for ifm, ifm2 in zip(ifm_shape.as_list(), ifm2_shape.as_list()):
if ifm != ifm2 and ifm == 1:
# Broadcasted FM needs to be in IFM2
return False
@@ -553,8 +554,8 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu
npu_op = NpuElementWiseOperation(elemwise_op)
if elemwise_op not in UNARY_ELEMWISE_OPS:
- ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
- ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
+ ifm_shape = None if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0]
+ ifm2_shape = None if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1]
if cmd.reversed_operands:
assert ifm_ifm2_correct_order(ifm_shape, ifm2_shape)
npu_op.reversed_operands = True
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index aaabddbf..c4176d96 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -242,8 +242,8 @@ def create_binary_elementwise(
if ifm2 is None:
ofm_shape = ifm_shape
else:
- in_shape = [] if ifm.shape == [] else ifm_shape.as_list()
- in2_shape = [] if ifm2.shape == [] else ifm2_shape.as_list()
+ in_shape = None if ifm.shape == [] else ifm_shape
+ in2_shape = None if ifm2.shape == [] else ifm2_shape
ofm_shape = ifm_shape if ifm_ifm2_correct_order(in_shape, in2_shape) else ifm2_shape
ofm = Tensor(ofm_shape.as_list(), dtype, f"{op.name}_tens0")
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 208b121e..79cd6421 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -227,7 +227,7 @@ class SchedulerOperation:
# Perform an IFM swap for certain binary elementwise operators
# in order to enable cascading, if the SchedOp conforms to
# Elementwise cascading rules.
- if self.op_type.is_binary_elementwise_op() and CascadeBuilder.element_wise_cascading_conformity(self):
+ if self.op_type.is_binary_elementwise_op() and CascadeBuilder.elementwise_cascading_conformity(self):
ifm1 = ps.ifm_tensor
ifm2 = ps.ifm2_tensor
ofm = ps.ofm_tensor