aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/scheduler.py
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-10-21 11:21:38 +0200
committerJohan Alfvén <johan.alfven@arm.com>2022-10-25 15:31:08 +0200
commit0f2e59f8bff0ca68794db1406e1264531da1d3a5 (patch)
treeef3c4420dd5fe47c9cd5cc1cd6bb79082962b923 /ethosu/vela/scheduler.py
parent2a285fca30d13f6577ef3e8154aea24713d728a5 (diff)
downloadethos-u-vela-0f2e59f8bff0ca68794db1406e1264531da1d3a5.tar.gz
MLBEDSW-7028: Fix compiler assert for elementwise op
- Refactored erroneously if statement that allowed illegal swapping between ifm1 and ifm2 for elementwise operators. Signed-off-by: Johan Alfven <johan.alfven@arm.com> Change-Id: Iec571f710824432edac9104d960f199f33a1b241
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r--ethosu/vela/scheduler.py22
1 files changed, 10 insertions, 12 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index ec7380a6..021bcc9e 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -227,19 +227,17 @@ class SchedulerOperation:
# Perform an IFM swap for certain binary elementwise operators
# in order to enable cascading, if the SchedOp conforms to
# Elementwise cascading rules.
- if self.op_type.is_binary_elementwise_op() and CascadeBuilder.elementwise_cascading_conformity(self):
- ifm1 = ps.ifm_tensor
- ifm2 = ps.ifm2_tensor
- ofm = ps.ofm_tensor
- assert ifm1.elements() > 0
- assert ifm2.elements() > 0
+ # The non-constant/non-scalar/non-broadcast IFM should be the primary input
+ if self.op_type.is_binary_elementwise_op():
+ ifm = self.parent_op.ifm
+ ifm2 = self.parent_op.ifm2
+ ofm = self.parent_op.ofm
- if (
- # The non-constant IFM should be the primary input
- (ifm1.ops[0].type == Op.Const and ifm2.ops[0].type != Op.Const)
- # The non-broadcasted IFM should be the primary input
- or (ifm1.shape != ofm.shape and ifm2.shape == ofm.shape)
- ):
+ ifm_can_be_primary = not (ifm.is_const or ifm.is_scalar or ifm.is_broadcast(ofm))
+ ifm2_can_be_primary = not (ifm2.is_const or ifm2.is_scalar or ifm2.is_broadcast(ofm))
+
+ if not ifm_can_be_primary and ifm2_can_be_primary:
+ # IFM2 is the primary input
self.reversed_operands = True
self.ifm, self.ifm2 = self.ifm2, self.ifm