diff options
author | Jonas Ohlsson <jonas.ohlsson@arm.com> | 2021-07-26 16:13:12 +0200 |
---|---|---|
committer | Jonas Ohlsson <jonas.ohlsson@arm.com> | 2021-07-27 11:06:27 +0200 |
commit | 45e653dbd81633b8d78215b16a9b2205e39dd8e2 (patch) | |
tree | 18b3073eac45e9e8d69a616ae96d7a3fbdef9663 /ethosu/vela/test | |
parent | c2449827ec55f49b6087e3e385fb3c4f6776dc6a (diff) | |
download | ethos-u-vela-45e653dbd81633b8d78215b16a9b2205e39dd8e2.tar.gz |
MLBEDSW-4853: Refactor supported operators
Refactor supported operators by breaking out model semantics
into its own class. Model semantics checked right after model
read.
Signed-off-by: Jonas Ohlsson <jonas.ohlsson@arm.com>
Change-Id: If442b189efcd91dda01af60b2b3adedfacdf2fad
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 460 | ||||
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py (renamed from ethosu/vela/test/test_supported_operators.py) | 408 |
2 files changed, 485 insertions, 383 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py new file mode 100644 index 00000000..4c329844 --- /dev/null +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -0,0 +1,460 @@ +# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Description: +# Unit tests for tflite_model_semantic +import numpy as np + +from ethosu.vela.data_type import DataType +from ethosu.vela.operation import Op +from ethosu.vela.operation import Padding +from ethosu.vela.tensor import create_const_tensor +from ethosu.vela.tensor import QuantizationParameters +from ethosu.vela.tensor import Tensor +from ethosu.vela.test import testutil +from ethosu.vela.tflite_model_semantic import TFLiteSemantic + +semantic_checker = TFLiteSemantic() + + +def test_constraint_tens_no_dynamic(): + # Tensors cannot be dynamic (no shape, not a scalar) + op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], []) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_tens_defined_shape(): + # Tensors cannot have None in them + op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8]) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_tens_output_scalar(): + # Scalar output is not allowed at all: + op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], []) + op.ofm.values = 0.5 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_tens_input_scalar(): + # Shapeless input is allowed if its of a certain type: + op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8]) + assert semantic_checker.is_operator_semantic_valid(op) + # Invalid shapeless input due to op type: + op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8]) + op.ifm.values = 0.5 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_tens_shape_size(): + # Tensors cannot be > 4D + op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_tens_quant_none_check(): + # Tensors must have quantization parameters + op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_tens_quant_scale(): + # Quantization scale cannot be infinite + qp = QuantizationParameters() + qp.zero_point = 0 + qp.scale_f32 = np.inf + op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_fc_output_2d_not_supp(): + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1]) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1]) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1]) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_fc_output_2d_is_supp(): + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024]) + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_conv_pass(): + # First test a simple conv passes + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]) + op.attrs = {"stride_w": 1, "stride_h": 1} + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_stride_type(): + # Stride width and height must be integer types + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) + op.attrs = {"stride_w": 1.5, "stride_h": "1"} + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_dilation_type(): + # Dilation width and height must be integer types + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) + op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"} + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_quant_scale_inf(): + # Test handling IFM scale/OFM scale is infinite + op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8]) + op.ifm.quantization.scale_f32 = np.float32(1e9) + op.ofm.quantization.scale_f32 = np.float32(1e-35) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_ofm_scale_too_small(): + # Tests handling of OFM scale < 1e-38 + shp = [1, 10, 20, 16] + op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),) + assert semantic_checker.is_operator_semantic_valid(op) + op.ofm.quantization.scale_f32 = 1e-43 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_matching_in_out_types(): + # Valid + op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8]) + op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME} + assert semantic_checker.is_operator_semantic_valid(op) + # Invalid. datatypes for ifm and ofm must match (default uint8) + op.ifm.dtype = DataType.int8 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_filter_type(): + # Filter width/height must be integers + op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8]) + op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME} + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_matching_shapes(): + # Softmax requires the ifm and ofm shapes to match + op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4]) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8]) + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_beta_value_range(): + # beta must be positive + op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8]) + op.attrs["beta"] = -1.0 + assert not semantic_checker.is_operator_semantic_valid(op) + op.attrs["beta"] = 0.0 + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_splitv_inferred(): + # SplitV requires a maximum of one inferred shape (-1) + qp = testutil.default_quant_params() + op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8]) + sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp) + op.add_input_tensor(sizes) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8]) + sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp) + op.add_input_tensor(sizes) + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_concat_pass(): + # A working concat + op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8]) + ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2") + ifm2.quantization = testutil.default_quant_params() + op.add_input_tensor(ifm2) + op.attrs["axis"] = 3 + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_axis_exists(): + # Missing axis attribute + op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8]) + ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2") + ifm2.quantization = testutil.default_quant_params() + op.add_input_tensor(ifm2) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_axis_valid(): + # Invalid axis attribute + op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8]) + ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2") + ifm2.quantization = testutil.default_quant_params() + op.add_input_tensor(ifm2) + op.attrs["axis"] = 7 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_matching_dimensionality(): + # Mismatching dimensionality: 4D+2D=4D + op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8]) + ifm2 = Tensor([1, 4], DataType.uint8, "in2") + ifm2.quantization = testutil.default_quant_params() + op.add_input_tensor(ifm2) + op.attrs["axis"] = 3 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_valid_dimensions(): + # Mismatching dimension value: + # ifm2 has w and h as 2, which is not the axis to concat and doesnt match ifm1 or ofm + op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8]) + ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2") + ifm2.quantization = testutil.default_quant_params() + op.add_input_tensor(ifm2) + op.attrs["axis"] = 3 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets): + qp = testutil.default_quant_params() + in0 = Tensor(in_shape, DataType.uint8, "in") + in0.quantization = qp + in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp) + in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp) + in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp) + out = Tensor(out_shape, DataType.uint8, "out") + out.quantization = qp + attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0} + return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs) + + +def create_pad_op( + in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32, +): + qp = testutil.default_quant_params() + in0 = Tensor(in_shape, in_dtype, "in") + in0.quantization = qp + pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype) + out = Tensor(out_shape, out_dtype, "out") + out.quantization = qp.clone() + op = testutil.create_op(Op.Pad, [in0, pad_tensor], out) + return op + + +def test_constraint_pad_input_count(): + # Incorrect number of input tensors (2) + op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],) + assert semantic_checker.is_operator_semantic_valid(op) + op.add_input_tensor(op.inputs[0].clone()) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def create_strided_slice(): + # Creates a valid strided slice operator with some valid inputs/outputs + op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0]) + op.attrs["begin_mask"] = 1 + op.attrs["end_mask"] = 9 + assert semantic_checker.is_operator_semantic_valid(op) + return op + + +def test_constraint_stridedslice_input_count(): + # Wrong number of input tensors + op = create_strided_slice() + op.add_input_tensor(op.inputs[0].clone()) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_stridedslice_inputs_const(): + # begin, end, stride values must not be None + op = create_strided_slice() + op.inputs[1].values = None + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_strided_slice() + op.inputs[2].values = None + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_strided_slice() + op.inputs[3].values = None + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_ellipsis_mask(): + # Unsemantic_checkered ellipsis mask + op = create_strided_slice() + op.attrs["ellipsis_mask"] = 1 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_axis_masks(): + op = create_strided_slice() + # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok + op.attrs["new_axis_mask"] = 2 + assert semantic_checker.is_operator_semantic_valid(op) + op = create_strided_slice() + op.attrs["shrink_axis_mask"] = 3 + assert semantic_checker.is_operator_semantic_valid(op) + # But setting both to non-zero is not semantic_checkered + op.attrs["new_axis_mask"] = 2 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_slice_ranges(): + # Examples where end offset <= begin offset + op = create_strided_slice() + op.inputs[1].values = [0, 7, 2, 0] + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_strided_slice() + op.inputs[2].values = [0, 7, 2, 0] + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_strided_slice() + op.attrs["begin_mask"] = 0 + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_strided_slice() + op.attrs["end_mask"] = 0 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_matching_inputs_types(): + # input data types must match (default is uint8) + op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8]) + op.ifm2.dtype = DataType.int8 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_matching_signed(): + # signed inputs require output to also be signed + op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8) + op.ofm.dtype = DataType.uint8 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_unsigned_valid(): + # unsigned inputs require output to be either: + op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8]) + # the same (default uint8) + assert semantic_checker.is_operator_semantic_valid(op) + op.ofm.dtype = DataType.int8 + assert not semantic_checker.is_operator_semantic_valid(op) + op.ofm.dtype = DataType.int16 + assert not semantic_checker.is_operator_semantic_valid(op) + # or int32 + op.ofm.dtype = DataType.int32 + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_matching_either_shapes(): + # BINARY CASE + # At least one ifm shape must match ofm's shape + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2]) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16]) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16]) + assert not semantic_checker.is_operator_semantic_valid(op) + + # UNARY CASE + # No second input so this is treated the same as requiring ifm shape to match ofm shape + op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_alpha_valid(): + # Alpha cannot be negative + op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2]) + op.attrs["alpha"] = 0 + assert semantic_checker.is_operator_semantic_valid(op) + op.attrs["alpha"] = -1 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_hardswish_dtype(): + # HardSwish operator dtype should be int8 or uint8, and input dtype must match output + # UINT8 + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8]) + assert semantic_checker.is_operator_semantic_valid(op) + # INT8 + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8) + assert semantic_checker.is_operator_semantic_valid(op) + + # Invalid + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32) + assert not semantic_checker.is_operator_semantic_valid(op) + + in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in") + out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out") + op = testutil.create_op(Op.HardSwish, [in_tens], out_tens) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_keep_dims_ifm_ofm(): + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4]) + op.attrs["keep_num_dims"] = True + assert not semantic_checker.is_operator_semantic_valid(op) + op.attrs["keep_num_dims"] = False + assert semantic_checker.is_operator_semantic_valid(op) + + +def create_mean(input_shape, output_shape, axis, datatype, attrs): + ifm = Tensor(input_shape, datatype, "in") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor(output_shape, datatype, "out") + ofm.quantization = testutil.default_quant_params() + if type(axis) is list: + indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8) + elif type(axis) is int: + indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8) + op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs) + return op + + +def test_mean_dtype(): + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + op.ifm.dtype = DataType.int16 + op.ofm.dtype = DataType.int16 + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_mean_axis(): + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True}) + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True}) + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True}) + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) + assert not semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index 38308154..af5dc174 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -15,55 +15,20 @@ # limitations under the License. # # Description: -# Unit tests for support_operators +# Unit tests for tflite support_operators import numpy as np from ethosu.vela.data_type import DataType from ethosu.vela.operation import ActivationFunction from ethosu.vela.operation import Op from ethosu.vela.operation import Padding -from ethosu.vela.supported_operators import SupportedOperators from ethosu.vela.tensor import create_const_tensor from ethosu.vela.tensor import QuantizationParameters from ethosu.vela.tensor import Tensor from ethosu.vela.test import testutil +from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators -support = SupportedOperators() - - -def test_constraint_tens_no_dynamic(): - # Tensors cannot be dynamic (no shape, not a scalar) - op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], []) - assert not support.is_operator_supported(op) - - -def test_constraint_tens_defined_shape(): - # Tensors cannot have None in them - op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8]) - assert not support.is_operator_supported(op) - - -def test_constraint_tens_output_scalar(): - # Scalar output is not allowed at all: - op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], []) - op.ofm.values = 0.5 - assert not support.is_operator_supported(op) - - -def test_constraint_tens_input_scalar(): - # Shapeless input is allowed if its of a certain type: - op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8]) - assert support.is_operator_supported(op) - # Invalid shapeless input due to op type: - op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8]) - op.ifm.values = 0.5 - assert not support.is_operator_supported(op) - - -def test_constraint_tens_shape_size(): - # Tensors cannot be > 4D - op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False) - assert not support.is_operator_supported(op) +support = TFLiteSupportedOperators() def test_constraint_tens_dtype(): @@ -86,21 +51,6 @@ def test_constraint_tens_dimension(): assert not support.is_operator_supported(op) -def test_constraint_tens_quant_none_check(): - # Tensors must have quantization parameters - op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None) - assert not support.is_operator_supported(op) - - -def test_constraint_tens_quant_scale(): - # Quantization scale cannot be infinite - qp = QuantizationParameters() - qp.zero_point = 0 - qp.scale_f32 = np.inf - op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp) - assert not support.is_operator_supported(op) - - def test_constraint_tens_quant_per_axis_not_supp(): # Quantization scale cannot be array-valued for elemwise ops qp = QuantizationParameters() @@ -123,15 +73,6 @@ def test_constraint_tens_quant_per_axis_is_supp(): assert support.is_operator_supported(op) -def test_constraint_fc_output_2d_not_supp(): - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1]) - assert not support.is_operator_supported(op) - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1]) - assert not support.is_operator_supported(op) - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1]) - assert not support.is_operator_supported(op) - - def test_constraint_fc_output_2d_is_supp(): op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4]) assert support.is_operator_supported(op) @@ -158,49 +99,35 @@ def test_constraint_faf_ofm_dtype(): def test_constraint_conv_pass(): # First test a simple conv passes - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]) op.attrs = {"stride_w": 1, "stride_h": 1} assert support.is_operator_supported(op) -def test_constraint_stride_type(): - # Stride width and height must be integer types - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8]) - op.attrs = {"stride_w": 1.5, "stride_h": "1"} - assert not support.is_operator_supported(op) - - def test_constraint_stride_range(): # Stride width and height must lie within a certain range - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) op.attrs = {"stride_w": 0, "stride_h": 20} assert not support.is_operator_supported(op) -def test_constraint_dilation_type(): - # Dilation width and height must be integer types - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8]) - op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"} - assert not support.is_operator_supported(op) - - def test_constraint_dilation_range(): # Dilation width and height must lie within a certain range - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 0, "dilation_h_factor": 20} assert not support.is_operator_supported(op) def test_constraint_dilated_height_range(): # Dilated kernel height must lie within a certain range - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1]) op.attrs = {"stride_w": 1, "stride_h": 1} assert not support.is_operator_supported(op) def test_constraint_dilated_product_range(): # Dilated kernel width x height must lie within a certain range - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1]) op.attrs = {"stride_w": 1, "stride_h": 1} assert not support.is_operator_supported(op) @@ -208,7 +135,7 @@ def test_constraint_dilated_product_range(): def test_constraint_weights_type(): # Weight tensor must be 8-bit op = testutil.create_op_with_quant_tensors( - Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16 + Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16 ) op.attrs = {"stride_w": 1, "stride_h": 1} assert not support.is_operator_supported(op) @@ -216,7 +143,7 @@ def test_constraint_weights_type(): def test_constraint_weights_const(): # Weight tensor cannot be non-const tensors - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) op.attrs = {"stride_w": 1, "stride_h": 1} weights = Tensor([64, 64, 1, 1], DataType.uint8, "weights") weights.quantization = testutil.default_quant_params() @@ -226,7 +153,7 @@ def test_constraint_weights_const(): def test_constraint_weights_limit(): # Sum of weights has a limit - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]) op.attrs = {"stride_w": 1, "stride_h": 1} op.weights.quantization.zero_point = np.array([[[[(127 * 65536) + 1]]]]) assert not support.is_operator_supported(op) @@ -252,28 +179,11 @@ def test_constraint_bias_40bit(): def test_constraint_batch_size(): - op = testutil.create_op_with_quant_tensors(Op.Conv2D, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]) op.attrs = {"stride_w": 1, "stride_h": 1} assert not support.is_operator_supported(op) -def test_constraint_quant_scale_inf(): - # Test handling IFM scale/OFM scale is infinite - op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8]) - op.ifm.quantization.scale_f32 = np.float32(1e9) - op.ofm.quantization.scale_f32 = np.float32(1e-35) - assert not support.is_operator_supported(op) - - -def test_constraint_ofm_scale_too_small(): - # Tests handling of OFM scale < 1e-38 - shp = [1, 10, 20, 16] - op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),) - assert support.is_operator_supported(op) - op.ofm.quantization.scale_f32 = 1e-43 - assert not support.is_operator_supported(op) - - def test_constraint_depth_multiplier(): # Valid. Depth multiplier is 1 so no further constraints op = testutil.create_op_with_quant_tensors( @@ -339,23 +249,6 @@ def test_constraint_tconv_valid(): assert not support.is_operator_supported(op) -def test_constraint_matching_in_out_types(): - # Valid - op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8]) - op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME} - assert support.is_operator_supported(op) - # Invalid. datatypes for ifm and ofm must match (default uint8) - op.ifm.dtype = DataType.int8 - assert not support.is_operator_supported(op) - - -def test_constraint_filter_type(): - # Filter width/height must be integers - op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8]) - op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME} - assert not support.is_operator_supported(op) - - def test_constraint_filter_range(): # Avg pool restrictions are dependent on padding: # SAME padding restricts both W and H to max 8 @@ -434,36 +327,6 @@ def test_constraint_resize(): assert not support.is_operator_supported(op) -def test_constraint_matching_shapes(): - # Softmax requires the ifm and ofm shapes to match - op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4]) - assert not support.is_operator_supported(op) - op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8]) - assert support.is_operator_supported(op) - - -def test_constraint_beta_value_range(): - # beta must be positive - op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8]) - op.attrs["beta"] = -1.0 - assert not support.is_operator_supported(op) - op.attrs["beta"] = 0.0 - assert support.is_operator_supported(op) - - -def test_constraint_splitv_inferred(): - # SplitV requires a maximum of one inferred shape (-1) - qp = testutil.default_quant_params() - op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8]) - sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp) - op.add_input_tensor(sizes) - assert not support.is_operator_supported(op) - op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8]) - sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp) - op.add_input_tensor(sizes) - assert support.is_operator_supported(op) - - def test_constraint_concat_pass(): # A working concat op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8]) @@ -474,59 +337,6 @@ def test_constraint_concat_pass(): assert support.is_operator_supported(op) -def test_constraint_axis_exists(): - # Missing axis attribute - op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8]) - ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2") - ifm2.quantization = testutil.default_quant_params() - op.add_input_tensor(ifm2) - assert not support.is_operator_supported(op) - - -def test_constraint_axis_valid(): - # Invalid axis attribute - op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8]) - ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2") - ifm2.quantization = testutil.default_quant_params() - op.add_input_tensor(ifm2) - op.attrs["axis"] = 7 - assert not support.is_operator_supported(op) - - -def test_constraint_matching_dimensionality(): - # Mismatching dimensionality: 4D+2D=4D - op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8]) - ifm2 = Tensor([1, 4], DataType.uint8, "in2") - ifm2.quantization = testutil.default_quant_params() - op.add_input_tensor(ifm2) - op.attrs["axis"] = 3 - assert not support.is_operator_supported(op) - - -def test_constraint_valid_dimensions(): - # Mismatching dimension value: - # ifm2 has w and h as 2, which is not the axis to concat and doesnt match ifm1 or ofm - op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8]) - ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2") - ifm2.quantization = testutil.default_quant_params() - op.add_input_tensor(ifm2) - op.attrs["axis"] = 3 - assert not support.is_operator_supported(op) - - -def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets): - qp = testutil.default_quant_params() - in0 = Tensor(in_shape, DataType.uint8, "in") - in0.quantization = qp - in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp) - in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp) - in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp) - out = Tensor(out_shape, DataType.uint8, "out") - out.quantization = qp - attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0} - return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs) - - def create_pad_op( in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32, ): @@ -540,14 +350,6 @@ def create_pad_op( return op -def test_constraint_pad_input_count(): - # Incorrect number of input tensors (2) - op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],) - assert support.is_operator_supported(op) - op.add_input_tensor(op.inputs[0].clone()) - assert not support.is_operator_supported(op) - - def test_constraint_padded_dimensions(): # Incorrect padding dimensions, can only pad width and height op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],) @@ -582,6 +384,19 @@ def test_constraint_pad_dtype(): assert not support.is_operator_supported(op) +def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets): + qp = testutil.default_quant_params() + in0 = Tensor(in_shape, DataType.uint8, "in") + in0.quantization = qp + in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp) + in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp) + in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp) + out = Tensor(out_shape, DataType.uint8, "out") + out.quantization = qp + attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0} + return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs) + + def create_strided_slice(): # Creates a valid strided slice operator with some valid inputs/outputs op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0]) @@ -591,26 +406,6 @@ def create_strided_slice(): return op -def test_constraint_stridedslice_input_count(): - # Wrong number of input tensors - op = create_strided_slice() - op.add_input_tensor(op.inputs[0].clone()) - assert not support.is_operator_supported(op) - - -def test_constraint_stridedslice_inputs_const(): - # begin, end, stride values must not be None - op = create_strided_slice() - op.inputs[1].values = None - assert not support.is_operator_supported(op) - op = create_strided_slice() - op.inputs[2].values = None - assert not support.is_operator_supported(op) - op = create_strided_slice() - op.inputs[3].values = None - assert not support.is_operator_supported(op) - - def test_constraint_stridedslice_stride_values(): # Unsupported strides op = create_strided_slice() @@ -618,70 +413,6 @@ def test_constraint_stridedslice_stride_values(): assert not support.is_operator_supported(op) -def test_constraint_ellipsis_mask(): - # Unsupported ellipsis mask - op = create_strided_slice() - op.attrs["ellipsis_mask"] = 1 - assert not support.is_operator_supported(op) - - -def test_constraint_axis_masks(): - op = create_strided_slice() - # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok - op.attrs["new_axis_mask"] = 2 - assert support.is_operator_supported(op) - op = create_strided_slice() - op.attrs["shrink_axis_mask"] = 3 - assert support.is_operator_supported(op) - # But setting both to non-zero is not supported - op.attrs["new_axis_mask"] = 2 - assert not support.is_operator_supported(op) - - -def test_constraint_slice_ranges(): - # Examples where end offset <= begin offset - op = create_strided_slice() - op.inputs[1].values = [0, 7, 2, 0] - assert not support.is_operator_supported(op) - op = create_strided_slice() - op.inputs[2].values = [0, 7, 2, 0] - assert not support.is_operator_supported(op) - op = create_strided_slice() - op.attrs["begin_mask"] = 0 - assert not support.is_operator_supported(op) - op = create_strided_slice() - op.attrs["end_mask"] = 0 - assert not support.is_operator_supported(op) - - -def test_constraint_matching_inputs_types(): - # input data types must match (default is uint8) - op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8]) - op.ifm2.dtype = DataType.int8 - assert not support.is_operator_supported(op) - - -def test_constraint_matching_signed(): - # signed inputs require output to also be signed - op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8) - op.ofm.dtype = DataType.uint8 - assert not support.is_operator_supported(op) - - -def test_constraint_unsigned_valid(): - # unsigned inputs require output to be either: - op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8]) - # the same (default uint8) - assert support.is_operator_supported(op) - op.ofm.dtype = DataType.int8 - assert not support.is_operator_supported(op) - op.ofm.dtype = DataType.int16 - assert not support.is_operator_supported(op) - # or int32 - op.ofm.dtype = DataType.int32 - assert support.is_operator_supported(op) - - def test_constraint_inputs_int32(): # both inputs must be type int32 op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8]) @@ -750,28 +481,6 @@ def test_constraint_elemwise_batch_size(): assert not support.is_operator_supported(op) -def test_constraint_matching_either_shapes(): - # BINARY CASE - # At least one ifm shape must match ofm's shape - op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4]) - assert support.is_operator_supported(op) - op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4]) - assert support.is_operator_supported(op) - op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2]) - assert not support.is_operator_supported(op) - op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16]) - assert not support.is_operator_supported(op) - op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16]) - assert not support.is_operator_supported(op) - - # UNARY CASE - # No second input so this is treated the same as requiring ifm shape to match ofm shape - op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32) - assert support.is_operator_supported(op) - op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32) - assert not support.is_operator_supported(op) - - def test_constraint_broadcast_shapes(): # BINARY CASE # Only allow broadcast to 1 dim, for 1 rank index @@ -800,46 +509,6 @@ def test_constraint_broadcast_shapes(): assert not support.is_operator_supported(op) -def test_constraint_alpha_valid(): - # Alpha cannot be negative - op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2]) - op.attrs["alpha"] = 0 - assert support.is_operator_supported(op) - op.attrs["alpha"] = -1 - assert not support.is_operator_supported(op) - - -def test_constraint_hardswish_dtype(): - # HardSwish operator dtype should be int8 or uint8, and input dtype must match output - # UINT8 - op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8]) - assert support.is_operator_supported(op) - # INT8 - op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8) - assert support.is_operator_supported(op) - - # Invalid - op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16) - assert not support.is_operator_supported(op) - op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16) - assert not support.is_operator_supported(op) - op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32) - assert not support.is_operator_supported(op) - - in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in") - out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out") - op = testutil.create_op(Op.HardSwish, [in_tens], out_tens) - assert not support.is_operator_supported(op) - - -def test_constraint_keep_dims_ifm_ofm(): - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4]) - op.attrs["keep_num_dims"] = True - assert not support.is_operator_supported(op) - op.attrs["keep_num_dims"] = False - assert support.is_operator_supported(op) - - def create_mean(input_shape, output_shape, axis, datatype, attrs): ifm = Tensor(input_shape, datatype, "in") ifm.quantization = testutil.default_quant_params() @@ -853,33 +522,6 @@ def create_mean(input_shape, output_shape, axis, datatype, attrs): return op -def test_mean_dtype(): - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) - assert support.is_operator_supported(op) - op.ifm.dtype = DataType.int16 - op.ofm.dtype = DataType.int16 - assert not support.is_operator_supported(op) - - -def test_mean_axis(): - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True}) - assert not support.is_operator_supported(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True}) - assert not support.is_operator_supported(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True}) - assert not support.is_operator_supported(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) - assert not support.is_operator_supported(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) - assert support.is_operator_supported(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True}) - assert support.is_operator_supported(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True}) - assert support.is_operator_supported(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True}) - assert support.is_operator_supported(op) - - def test_mean_hw_product(): op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {}) assert support.is_operator_supported(op) |