diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index 6f3553d8..f2ad8586 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -613,9 +613,26 @@ def create_mean(input_shape, output_shape, axis, datatype, attrs): def test_mean_hw_product(): - op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {}) + # max kernel size checks + op = create_mean([1, 4096, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {}) assert support.is_operator_supported(op) - op = create_mean([1, 65, 64, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) + op = create_mean([1, 4097, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {}) + assert not support.is_operator_supported(op) + + op = create_mean([1, 2048, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.uint8, {}) + assert support.is_operator_supported(op) + op = create_mean([1, 2049, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.uint8, {}) + assert not support.is_operator_supported(op) + + op = create_mean([1, 16, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.int16, {}) + assert support.is_operator_supported(op) + op = create_mean([1, 17, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.int16, {}) + assert not support.is_operator_supported(op) + + # h > 4096 is OK but w > 4096 is not + op = create_mean([1, 4097, 10, 16], [1, 1, 1, 16], [1, 2], DataType.uint8, {"keep_dims": True}) + assert support.is_operator_supported(op) + op = create_mean([1, 10, 4097, 16], [1, 1, 1, 16], [1, 2], DataType.int16, {"keep_dims": True}) assert not support.is_operator_supported(op) |