diff options
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r-- | ethosu/vela/test/test_supported_operators.py | 36 |
1 files changed, 34 insertions, 2 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 86d24757..72ccad24 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -658,12 +658,16 @@ def test_constraint_elemwise_batch_size(): def test_constraint_matching_either_shapes(): # BINARY CASE # At least one ifm shape must match ofm's shape - op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [4, 4], [2, 2]) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4]) assert support.is_operator_supported(op) - op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [2, 2], [2, 2]) + op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4]) assert support.is_operator_supported(op) op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2]) assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16]) + assert not support.is_operator_supported(op) # UNARY CASE # No second input so this is treated the same as requiring ifm shape to match ofm shape @@ -673,6 +677,34 @@ def test_constraint_matching_either_shapes(): assert not support.is_operator_supported(op) +def test_constraint_broadcast_shapes(): + # BINARY CASE + # Only allow broadcast to 1 dim, for 1 rank index + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4], [1, 2, 4], [1, 2, 4]) + assert support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 1, 4], [1, 2, 4]) + assert support.is_operator_supported(op) + # Only allow broadcast to 1 dim, for 3 rank indexes + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16]) + assert support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16]) + assert support.is_operator_supported(op) + # One broadcast dim not 1 + op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 4, 4], [1, 4, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 4], [1, 2, 4], [1, 4, 4]) + assert not support.is_operator_supported(op) + # OFM shape dim largest ifm/ifm2 shape dim + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16]) + assert not support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16]) + assert not support.is_operator_supported(op) + + def test_constraint_alpha_valid(): # Alpha cannot be negative op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2]) |