author      Andreas Nevalainen <andreas.nevalainen@arm.com>    2020-11-19 14:40:35 +0100
committer   tim.hall <tim.hall@arm.com>                        2020-11-20 14:32:27 +0000
commit      d059d8b64fedde37137b5da4d5e3a082154e954d (patch)
tree        85e9c6dc6fea14aaeb1252c97c3ba3293b716a6a
parent      9dce04c3f5c12da360185e852a9524647a4c6272 (diff)
MLBEDSW-3157: Add test for broadcast shapes
Change-Id: Ifbd6c053ac618bedce0f56fe5c4c647a71d9cc46
Signed-off-by: Andreas Nevalainen <andreas.nevalainen@arm.com>
-rw-r--r--    ethosu/vela/supported_operators.py              24
-rw-r--r--    ethosu/vela/test/test_supported_operators.py    36
2 files changed, 58 insertions(+), 2 deletions(-)
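The patch adds a constraint for element-wise broadcasting: the trailing dimensions of the two inputs are aligned, each aligned pair must match or contain a 1, and each output dimension must equal the larger of the two input dimensions. A minimal standalone sketch of that rule (illustrative only; plain Python lists as shapes, not the Vela API, and the helper name is made up):

def broadcast_shapes_ok(ifm_shape, ifm2_shape, ofm_shape):
    # Align trailing dimensions, as the new constraint does
    size = min(len(ifm_shape), len(ifm2_shape))
    for i, i2, o in zip(ifm_shape[-size:], ifm2_shape[-size:], ofm_shape[-size:]):
        # Each aligned pair must match or one side must be 1,
        # and the output must carry the larger of the two
        if not (i == i2 or i == 1 or i2 == 1) or o != max(i, i2):
            return False
    return True

assert broadcast_shapes_ok([1, 1, 4], [1, 2, 4], [1, 2, 4])      # broadcast from IFM1 only
assert not broadcast_shapes_ok([1, 2, 4], [1, 4, 4], [1, 4, 4])  # mismatched dims 2 and 4, neither is 1
assert not broadcast_shapes_ok([1, 4], [4, 4], [1, 4])           # OFM dim 1 is smaller than the largest input dim 4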
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 6dcb27d..6bbb04b 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -211,14 +211,17 @@ class SupportedOperators:
         for op_type in SupportedOperators.binary_elem_wise_min_max_ops:
             self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
             self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_quantization_parameters)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
         # Binary Add/Mul/Sub specific checks:
         for op_type in SupportedOperators.binary_elem_wise_add_mul_sub:
             self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_inputs_types)
             self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_signed)
             self.specific_constraints[op_type].append(SupportedOperators.constraint_unsigned_valid)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
         # Binary Shift specific checks:
         for op_type in SupportedOperators.binary_elem_wise_shift_ops:
             self.specific_constraints[op_type].append(SupportedOperators.constraint_inputs_int32)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
         # SHL specific checks:
         self.specific_constraints[Op.SHL].append(SupportedOperators.constraint_output_int32)
@@ -883,6 +886,27 @@ class SupportedOperators:
         return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"

     @staticmethod
+    def constraint_broadcast_shapes(op):
+        "Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2"
+        ifm_shape = op.ifm.shape
+        ifm2_shape = op.ifm2.shape if op.ifm2 else None
+        ofm_shape = op.ofm.shape
+        valid = True
+        if ifm_shape is not None and ifm2_shape is not None:
+            # align trailing dimensions
+            size = min(len(ifm_shape), len(ifm2_shape))
+            for i, i2, o in zip(ifm_shape[-size:], ifm2_shape[-size:], ofm_shape[-size:]):
+                mi = max(i, i2)
+                # Input dimensions should match, or one of them should be 1
+                # The output dimension should match the largest input dimension; together with
+                # constraint_matching_either_shapes this ensures broadcast from only one input
+                if not (i == i2 or i == 1 or i2 == 1) or o != mi:
+                    valid = False
+                    break
+
+        return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}"
+
+    @staticmethod
     def constraint_alpha_valid(op):
         "Alpha must not be negative"
         alpha = op.attrs["alpha"]
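To see how the new check interacts with the existing constraint_matching_either_shapes (an illustrative walk-through, not part of the patch): on its own, constraint_broadcast_shapes would accept ifm_shape=[1, 4, 1, 16], ifm2_shape=[1, 1, 4, 1] and ofm_shape=[1, 4, 4, 16], because every aligned pair matches or contains a 1 and each output dimension takes the larger value; it is the either-shapes constraint that rejects this case, since neither input shape equals the output shape. With ofm_shape=[1, 4, 1, 16] the roles flip: the either-shapes check passes, but the new constraint fails on the third axis, where the inputs are 1 and 4 and the output stays 1 instead of 4. Using the hypothetical broadcast_shapes_ok sketch from above:

broadcast_shapes_ok([1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])  # True: each axis broadcasts from one input or the other
broadcast_shapes_ok([1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16])  # False: third axis output is 1, largest input dim is 4

Both of these shape combinations appear in the new tests below.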
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 86d2475..72ccad2 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -658,12 +658,16 @@ def test_constraint_elemwise_batch_size():
 def test_constraint_matching_either_shapes():
     # BINARY CASE
     # At least one ifm shape must match ofm's shape
-    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [4, 4], [2, 2])
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
     assert support.is_operator_supported(op)
-    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [2, 2], [2, 2])
+    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
     assert support.is_operator_supported(op)
     op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
     assert not support.is_operator_supported(op)
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
+    assert not support.is_operator_supported(op)
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
+    assert not support.is_operator_supported(op)
     # UNARY CASE
     # No second input so this is treated the same as requiring ifm shape to match ofm shape
@@ -673,6 +677,34 @@ def test_constraint_matching_either_shapes():
     assert not support.is_operator_supported(op)
+def test_constraint_broadcast_shapes():
+    # BINARY CASE
+    # Only allow broadcast from a dimension of 1, here at a single rank index
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4], [1, 2, 4], [1, 2, 4])
+    assert support.is_operator_supported(op)
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 1, 4], [1, 2, 4])
+    assert support.is_operator_supported(op)
+    # Only allow broadcast from a dimension of 1, here at three rank indices
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16])
+    assert support.is_operator_supported(op)
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16])
+    assert support.is_operator_supported(op)
+    # One broadcast dimension is not 1
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 4, 4], [1, 4, 4])
+    assert not support.is_operator_supported(op)
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 4], [1, 2, 4], [1, 4, 4])
+    assert not support.is_operator_supported(op)
+    # OFM shape dim does not match the largest ifm/ifm2 shape dim
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4])
+    assert not support.is_operator_supported(op)
+    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [1, 4])
+    assert not support.is_operator_supported(op)
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16])
+    assert not support.is_operator_supported(op)
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16])
+    assert not support.is_operator_supported(op)
+
+
 def test_constraint_alpha_valid():
     # Alpha cannot be negative
     op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])