diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index c242063d..2e0936d0 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -195,11 +195,11 @@ def test_constraint_splitv_inferred(): # SplitV requires a maximum of one inferred shape (-1) qp = testutil.default_quant_params() op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8]) - sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp) + sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], quantization=qp) op.add_input_tensor(sizes) assert not semantic_checker.is_operator_semantic_valid(op) op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8]) - sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp) + sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], quantization=qp) op.add_input_tensor(sizes) assert semantic_checker.is_operator_semantic_valid(op) @@ -278,7 +278,8 @@ def create_pad_op( qp = testutil.default_quant_params() in0 = Tensor(in_shape, in_dtype, "in") in0.quantization = qp - pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype) + shape = [] if padding == [] else list(np.shape(padding)) + pad_tensor = create_const_tensor(name="pad", shape=shape, values=padding, dtype=pad_dtype) out = Tensor(out_shape, out_dtype, "out") out.quantization = qp.clone() op = testutil.create_op(Op.Pad, [in0, pad_tensor], out) @@ -449,9 +450,9 @@ def create_mean(input_shape, output_shape, axis, datatype, attrs): ofm = Tensor(output_shape, datatype, "out") ofm.quantization = testutil.default_quant_params() if type(axis) is list: - indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8) + indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis) elif type(axis) is int: - indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8) + indices = create_const_tensor("indices", [], DataType.int32, axis) op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs) return op |