aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 6dcb27d0..6bbb04b9 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -211,14 +211,17 @@ class SupportedOperators:
for op_type in SupportedOperators.binary_elem_wise_min_max_ops:
self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_quantization_parameters)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
# Binary Add/Mul/Sub specific checks:
for op_type in SupportedOperators.binary_elem_wise_add_mul_sub:
self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_inputs_types)
self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_signed)
self.specific_constraints[op_type].append(SupportedOperators.constraint_unsigned_valid)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
# Binary Shift specific checks:
for op_type in SupportedOperators.binary_elem_wise_shift_ops:
self.specific_constraints[op_type].append(SupportedOperators.constraint_inputs_int32)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
# SHL specific checks:
self.specific_constraints[Op.SHL].append(SupportedOperators.constraint_output_int32)
@@ -883,6 +886,27 @@ class SupportedOperators:
return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"
@staticmethod
+ def constraint_broadcast_shapes(op):
+ "Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2"
+ ifm_shape = op.ifm.shape
+ ifm2_shape = op.ifm2.shape if op.ifm2 else None
+ ofm_shape = op.ofm.shape
+ valid = True
+ if ifm_shape is not None and ifm2_shape is not None:
+ # align trailing dimensions
+ size = min(len(ifm_shape), len(ifm2_shape))
+ for i, i2, o in zip(ifm_shape[-size:], ifm2_shape[-size:], ofm_shape[-size:]):
+ mi = max(i, i2)
+ # Input dimensions should match or one should be of dimension 1
+ # Output dimension should match the largest input dimension, together
+ # with constraint_match_either_shapes ensures broadcast from only one input
+ if not (i == i2 or i == 1 or i2 == 1) or o != mi:
+ valid = False
+ break
+
+ return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}"
+
+ @staticmethod
def constraint_alpha_valid(op):
"Alpha must not be negative"
alpha = op.attrs["alpha"]