aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py22
1 files changed, 19 insertions, 3 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 84f99160..1e5dbd4d 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -128,7 +128,14 @@ def test_constraint_quant_scale_inf():
def test_constraint_ofm_scale_too_small():
# Tests handling of OFM scale < 1e-38
shp = [1, 10, 20, 16]
- op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),)
+ op = testutil.create_elemwise_op(
+ Op.Mul,
+ "mul",
+ shp,
+ shp,
+ shp,
+ ofm_quant=testutil.default_quant_params(),
+ )
assert semantic_checker.is_operator_semantic_valid(op)
op.ofm.quantization.scale_f32 = 1e-43
assert not semantic_checker.is_operator_semantic_valid(op)
@@ -245,7 +252,12 @@ def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
def create_pad_op(
- in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
+ in_shape,
+ out_shape,
+ padding,
+ in_dtype=DataType.int8,
+ out_dtype=DataType.int8,
+ pad_dtype=DataType.int32,
):
qp = testutil.default_quant_params()
in0 = Tensor(in_shape, in_dtype, "in")
@@ -259,7 +271,11 @@ def create_pad_op(
def test_constraint_pad_input_count():
# Incorrect number of input tensors (2)
- op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
+ op = create_pad_op(
+ in_shape=[1, 1, 1, 1],
+ out_shape=[1, 3, 3, 1],
+ padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+ )
assert semantic_checker.is_operator_semantic_valid(op)
op.add_input_tensor(op.inputs[0].clone())
assert not semantic_checker.is_operator_semantic_valid(op)