aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py22
1 files changed, 19 insertions, 3 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 1e5dbd4d..2d6ca15a 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -81,11 +81,13 @@ def test_constraint_tens_quant_scale():
def test_constraint_fc_output_2d_not_supp():
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [7, 4, 6], [3, 2, 2, 8], weights_shape=[1, 9, 1])
assert not semantic_checker.is_operator_semantic_valid(op)
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 6, 1], [3, 7, 4], weights_shape=[1, 1, 7, 1])
assert not semantic_checker.is_operator_semantic_valid(op)
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 1, 4, 7], [1, 9], weights_shape=[12, 3])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4], [9], weights_shape=[3, 2])
assert not semantic_checker.is_operator_semantic_valid(op)
@@ -94,6 +96,20 @@ def test_constraint_fc_output_2d_is_supp():
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1, 1], weights_shape=[12, 1, 1, 1])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1], weights_shape=[12, 1, 1, 1])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [1, 1, 3, 2], weights_shape=[12, 1, 1, 1])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 1, 1], weights_shape=[12, 1, 1, 1])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(
+ Op.FullyConnected, [12, 1, 1, 1], [1, 1, 24], weights_shape=[12, 1, 1, 1]
+ )
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1, 3, 4], weights_shape=[1, 1, 1, 1])
+ assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_conv_pass():