From d059d8b64fedde37137b5da4d5e3a082154e954d Mon Sep 17 00:00:00 2001 From: Andreas Nevalainen Date: Thu, 19 Nov 2020 14:40:35 +0100 Subject: MLBEDSW-3157: Add test for broadcast shapes Change-Id: Ifbd6c053ac618bedce0f56fe5c4c647a71d9cc46 Signed-off-by: Andreas Nevalainen --- ethosu/vela/supported_operators.py | 24 +++++++++++++++++++ ethosu/vela/test/test_supported_operators.py | 36 ++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 6dcb27d0..6bbb04b9 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -211,14 +211,17 @@ class SupportedOperators: for op_type in SupportedOperators.binary_elem_wise_min_max_ops: self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types) self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_quantization_parameters) + self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes) # Binary Add/Mul/Sub specific checks: for op_type in SupportedOperators.binary_elem_wise_add_mul_sub: self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_inputs_types) self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_signed) self.specific_constraints[op_type].append(SupportedOperators.constraint_unsigned_valid) + self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes) # Binary Shift specific checks: for op_type in SupportedOperators.binary_elem_wise_shift_ops: self.specific_constraints[op_type].append(SupportedOperators.constraint_inputs_int32) + self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes) # SHL specific checks: self.specific_constraints[Op.SHL].append(SupportedOperators.constraint_output_int32) @@ -882,6 +885,27 @@ class SupportedOperators: valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape) return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}" + @staticmethod + def constraint_broadcast_shapes(op): + "Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2" + ifm_shape = op.ifm.shape + ifm2_shape = op.ifm2.shape if op.ifm2 else None + ofm_shape = op.ofm.shape + valid = True + if ifm_shape is not None and ifm2_shape is not None: + # align trailing dimensions + size = min(len(ifm_shape), len(ifm2_shape)) + for i, i2, o in zip(ifm_shape[-size:], ifm2_shape[-size:], ofm_shape[-size:]): + mi = max(i, i2) + # Input dimensions should match or one should be of dimension 1 + # Output dimension should match the largest input dimension, together + # with constraint_match_either_shapes ensures broadcast from only one input + if not (i == i2 or i == 1 or i2 == 1) or o != mi: + valid = False + break + + return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}" + @staticmethod def constraint_alpha_valid(op): "Alpha must not be negative" diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 86d24757..72ccad24 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -658,12 +658,16 @@ def test_constraint_elemwise_batch_size(): def test_constraint_matching_either_shapes(): # BINARY CASE # At least one ifm shape must match ofm's shape - op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [4, 4], [2, 2]) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4]) assert support.is_operator_supported(op) - op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [2, 2], [2, 2]) + op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4]) assert support.is_operator_supported(op) op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2]) assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16]) + assert not support.is_operator_supported(op) # UNARY CASE # No second input so this is treated the same as requiring ifm shape to match ofm shape @@ -673,6 +677,34 @@ def test_constraint_matching_either_shapes(): assert not support.is_operator_supported(op) +def test_constraint_broadcast_shapes(): + # BINARY CASE + # Only allow broadcast to 1 dim, for 1 rank index + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4], [1, 2, 4], [1, 2, 4]) + assert support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 1, 4], [1, 2, 4]) + assert support.is_operator_supported(op) + # Only allow broadcast to 1 dim, for 3 rank indexes + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16]) + assert support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16]) + assert support.is_operator_supported(op) + # One broadcast dim not 1 + op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 4, 4], [1, 4, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 4], [1, 2, 4], [1, 4, 4]) + assert not support.is_operator_supported(op) + # OFM shape dim largest ifm/ifm2 shape dim + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16]) + assert not support.is_operator_supported(op) + + def test_constraint_alpha_valid(): # Alpha cannot be negative op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2]) -- cgit v1.2.1