diff options
-rw-r--r-- | ethosu/vela/supported_operators.py | 24 | ||||
-rw-r--r-- | ethosu/vela/test/test_supported_operators.py | 36 |
2 files changed, 58 insertions, 2 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 6dcb27d0..6bbb04b9 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -211,14 +211,17 @@ class SupportedOperators: for op_type in SupportedOperators.binary_elem_wise_min_max_ops: self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types) self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_quantization_parameters) + self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes) # Binary Add/Mul/Sub specific checks: for op_type in SupportedOperators.binary_elem_wise_add_mul_sub: self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_inputs_types) self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_signed) self.specific_constraints[op_type].append(SupportedOperators.constraint_unsigned_valid) + self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes) # Binary Shift specific checks: for op_type in SupportedOperators.binary_elem_wise_shift_ops: self.specific_constraints[op_type].append(SupportedOperators.constraint_inputs_int32) + self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes) # SHL specific checks: self.specific_constraints[Op.SHL].append(SupportedOperators.constraint_output_int32) @@ -883,6 +886,27 @@ class SupportedOperators: return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}" @staticmethod + def constraint_broadcast_shapes(op): + "Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2" + ifm_shape = op.ifm.shape + ifm2_shape = op.ifm2.shape if op.ifm2 else None + ofm_shape = op.ofm.shape + valid = True + if ifm_shape is not None and ifm2_shape is not None: + # align trailing dimensions + size = min(len(ifm_shape), len(ifm2_shape)) + for i, i2, o in zip(ifm_shape[-size:], ifm2_shape[-size:], ofm_shape[-size:]): + mi = max(i, i2) + # Input dimensions should match or one should be of dimension 1 + # Output dimension should match the largest input dimension, together + # with constraint_match_either_shapes ensures broadcast from only one input + if not (i == i2 or i == 1 or i2 == 1) or o != mi: + valid = False + break + + return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}" + + @staticmethod def constraint_alpha_valid(op): "Alpha must not be negative" alpha = op.attrs["alpha"] diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 86d24757..72ccad24 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -658,12 +658,16 @@ def test_constraint_elemwise_batch_size(): def test_constraint_matching_either_shapes(): # BINARY CASE # At least one ifm shape must match ofm's shape - op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [4, 4], [2, 2]) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4]) assert support.is_operator_supported(op) - op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [2, 2], [2, 2]) + op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4]) assert support.is_operator_supported(op) op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2]) assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16]) + assert not support.is_operator_supported(op) # UNARY CASE # No second input so this is treated the same as requiring ifm shape to match ofm shape @@ -673,6 +677,34 @@ def test_constraint_matching_either_shapes(): assert not support.is_operator_supported(op) +def test_constraint_broadcast_shapes(): + # BINARY CASE + # Only allow broadcast to 1 dim, for 1 rank index + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4], [1, 2, 4], [1, 2, 4]) + assert support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 1, 4], [1, 2, 4]) + assert support.is_operator_supported(op) + # Only allow broadcast to 1 dim, for 3 rank indexes + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16]) + assert support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16]) + assert support.is_operator_supported(op) + # One broadcast dim not 1 + op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 4, 4], [1, 4, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 4], [1, 2, 4], [1, 4, 4]) + assert not support.is_operator_supported(op) + # OFM shape dim largest ifm/ifm2 shape dim + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16]) + assert not support.is_operator_supported(op) + + def test_constraint_alpha_valid(): # Alpha cannot be negative op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2]) |