aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_supported_operators.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 6401d29..832d60f 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -945,3 +945,11 @@ def test_constraint_hardswish_dtype():
out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
assert not support.is_operator_supported(op)
+
+
+def test_constraint_keep_dims_ifm_ofm():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
+ op.attrs["keep_num_dims"] = True
+ assert not support.is_operator_supported(op)
+ op.attrs["keep_num_dims"] = False
+ assert support.is_operator_supported(op)