aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_supported_operators.py36
1 files changed, 34 insertions, 2 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 86d24757..72ccad24 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -658,12 +658,16 @@ def test_constraint_elemwise_batch_size():
def test_constraint_matching_either_shapes():
# BINARY CASE
# At least one ifm shape must match ofm's shape
- op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [4, 4], [2, 2])
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
assert support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [2, 2], [2, 2])
+ op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
assert support.is_operator_supported(op)
op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
assert not support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
+ assert not support.is_operator_supported(op)
# UNARY CASE
# No second input so this is treated the same as requiring ifm shape to match ofm shape
@@ -673,6 +677,34 @@ def test_constraint_matching_either_shapes():
assert not support.is_operator_supported(op)
+def test_constraint_broadcast_shapes():
+ # BINARY CASE
+ # Only allow broadcast to 1 dim, for 1 rank index
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4], [1, 2, 4], [1, 2, 4])
+ assert support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 1, 4], [1, 2, 4])
+ assert support.is_operator_supported(op)
+ # Only allow broadcast to 1 dim, for 3 rank indexes
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16])
+ assert support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16])
+ assert support.is_operator_supported(op)
+ # One broadcast dim not 1
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 4, 4], [1, 4, 4])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 4], [1, 2, 4], [1, 4, 4])
+ assert not support.is_operator_supported(op)
+ # OFM shape dim largest ifm/ifm2 shape dim
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16])
+ assert not support.is_operator_supported(op)
+
+
def test_constraint_alpha_valid():
# Alpha cannot be negative
op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])