author      Jonas Ohlsson <jonas.ohlsson@arm.com>    2021-07-26 16:13:12 +0200
committer   Jonas Ohlsson <jonas.ohlsson@arm.com>    2021-07-27 11:06:27 +0200
commit      45e653dbd81633b8d78215b16a9b2205e39dd8e2 (patch)
tree        18b3073eac45e9e8d69a616ae96d7a3fbdef9663 /ethosu/vela/test
parent      c2449827ec55f49b6087e3e385fb3c4f6776dc6a (diff)
download    ethos-u-vela-45e653dbd81633b8d78215b16a9b2205e39dd8e2.tar.gz
MLBEDSW-4853: Refactor supported operators
Refactor supported operators by breaking out model semantics into its own class. Model semantics are now checked right after the model is read.

Signed-off-by: Jonas Ohlsson <jonas.ohlsson@arm.com>
Change-Id: If442b189efcd91dda01af60b2b3adedfacdf2fad
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r--  ethosu/vela/test/test_tflite_model_semantic.py  460
-rw-r--r--  ethosu/vela/test/test_tflite_supported_operators.py (renamed from ethosu/vela/test/test_supported_operators.py)  408
2 files changed, 485 insertions, 383 deletions
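This change splits operator validation in two: TFLiteSemantic (new here, exercised by test_tflite_model_semantic.py) rejects malformed networks, while TFLiteSupportedOperators (the renamed SupportedOperators, exercised by test_tflite_supported_operators.py) decides NPU placement. A minimal sketch of how the two stages compose, using only the public entry points the tests below exercise; the validate_op wrapper and its placement are illustrative assumptions, not vela's actual integration code:

    from ethosu.vela.tflite_model_semantic import TFLiteSemantic
    from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators

    semantic_checker = TFLiteSemantic()
    support = TFLiteSupportedOperators()

    def validate_op(op):
        # Stage 1: semantic validity, checked right after model read.
        # A failure here means the operator itself is malformed.
        if not semantic_checker.is_operator_semantic_valid(op):
            return False
        # Stage 2: NPU constraints. A semantically valid operator that
        # fails here can still run, just not on the NPU.
        return support.is_operator_supported(op)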
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
new file mode 100644
index 00000000..4c329844
--- /dev/null
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -0,0 +1,460 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Unit tests for tflite_model_semantic
+import numpy as np
+
+from ethosu.vela.data_type import DataType
+from ethosu.vela.operation import Op
+from ethosu.vela.operation import Padding
+from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import QuantizationParameters
+from ethosu.vela.tensor import Tensor
+from ethosu.vela.test import testutil
+from ethosu.vela.tflite_model_semantic import TFLiteSemantic
+
+semantic_checker = TFLiteSemantic()
+
+
+def test_constraint_tens_no_dynamic():
+ # Tensors cannot be dynamic (no shape, not a scalar)
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_defined_shape():
+ # Tensors cannot have None in them
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_output_scalar():
+ # Scalar output is not allowed at all:
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
+ op.ofm.values = 0.5
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_input_scalar():
+ # Shapeless input is allowed if it's of a certain type:
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # Invalid shapeless input due to op type:
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
+ op.ifm.values = 0.5
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_shape_size():
+ # Tensors cannot be > 4D
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_quant_none_check():
+ # Tensors must have quantization parameters
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_quant_scale():
+ # Quantization scale cannot be infinite
+ qp = QuantizationParameters()
+ qp.zero_point = 0
+ qp.scale_f32 = np.inf
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_fc_output_2d_not_supp():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_fc_output_2d_is_supp():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_conv_pass():
+ # First test a simple conv passes
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 1, "stride_h": 1}
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_stride_type():
+ # Stride width and height must be integer types
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 1.5, "stride_h": "1"}
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_dilation_type():
+ # Dilation width and height must be integer types
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_quant_scale_inf():
+ # Test handling when IFM scale / OFM scale is infinite
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.ifm.quantization.scale_f32 = np.float32(1e9)
+ op.ofm.quantization.scale_f32 = np.float32(1e-35)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_ofm_scale_too_small():
+ # Tests handling of OFM scale < 1e-38
+ shp = [1, 10, 20, 16]
+ op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.ofm.quantization.scale_f32 = 1e-43
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_in_out_types():
+ # Valid
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # Invalid: datatypes for ifm and ofm must match (default is uint8)
+ op.ifm.dtype = DataType.int8
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_filter_type():
+ # Filter width/height must be integers
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_shapes():
+ # Softmax requires the ifm and ofm shapes to match
+ op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_beta_value_range():
+ # beta must be positive
+ op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
+ op.attrs["beta"] = -1.0
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.attrs["beta"] = 0.0
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_splitv_inferred():
+ # SplitV requires a maximum of one inferred shape (-1)
+ qp = testutil.default_quant_params()
+ op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
+ sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
+ op.add_input_tensor(sizes)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
+ sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
+ op.add_input_tensor(sizes)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_concat_pass():
+ # A working concat
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_axis_exists():
+ # Missing axis attribute
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_axis_valid():
+ # Invalid axis attribute
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 7
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_dimensionality():
+ # Mismatching dimensionality: 4D+2D=4D
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_valid_dimensions():
+ # Mismatching dimension value:
+ # ifm2 has w and h as 2, which is not the concat axis and doesn't match ifm1 or ofm
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, DataType.uint8, "in")
+ in0.quantization = qp
+ in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
+ in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
+ in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
+ out = Tensor(out_shape, DataType.uint8, "out")
+ out.quantization = qp
+ attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
+ return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
+
+
+def create_pad_op(
+ in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
+):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, in_dtype, "in")
+ in0.quantization = qp
+ pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+ out = Tensor(out_shape, out_dtype, "out")
+ out.quantization = qp.clone()
+ op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+ return op
+
+
+def test_constraint_pad_input_count():
+ # Incorrect number of input tensors (Pad expects exactly 2)
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.add_input_tensor(op.inputs[0].clone())
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def create_strided_slice():
+ # Creates a valid strided slice operator with some valid inputs/outputs
+ op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
+ op.attrs["begin_mask"] = 1
+ op.attrs["end_mask"] = 9
+ assert semantic_checker.is_operator_semantic_valid(op)
+ return op
+
+
+def test_constraint_stridedslice_input_count():
+ # Wrong number of input tensors
+ op = create_strided_slice()
+ op.add_input_tensor(op.inputs[0].clone())
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_stridedslice_inputs_const():
+ # begin, end, stride values must not be None
+ op = create_strided_slice()
+ op.inputs[1].values = None
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.inputs[2].values = None
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.inputs[3].values = None
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_ellipsis_mask():
+ # Unsupported ellipsis mask
+ op = create_strided_slice()
+ op.attrs["ellipsis_mask"] = 1
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_axis_masks():
+ op = create_strided_slice()
+ # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
+ op.attrs["new_axis_mask"] = 2
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.attrs["shrink_axis_mask"] = 3
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # But setting both to non-zero is not valid
+ op.attrs["new_axis_mask"] = 2
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_slice_ranges():
+ # Examples where end offset <= begin offset
+ op = create_strided_slice()
+ op.inputs[1].values = [0, 7, 2, 0]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.inputs[2].values = [0, 7, 2, 0]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.attrs["begin_mask"] = 0
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.attrs["end_mask"] = 0
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_inputs_types():
+ # input data types must match (default is uint8)
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ op.ifm2.dtype = DataType.int8
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_signed():
+ # signed inputs require output to also be signed
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+ op.ofm.dtype = DataType.uint8
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_unsigned_valid():
+ # unsigned inputs require output to be either:
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ # the same (default uint8)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.ofm.dtype = DataType.int8
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.ofm.dtype = DataType.int16
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # or int32
+ op.ofm.dtype = DataType.int32
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_either_shapes():
+ # BINARY CASE
+ # At least one ifm shape must match ofm's shape
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+ # UNARY CASE
+ # No second input so this is treated the same as requiring ifm shape to match ofm shape
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_alpha_valid():
+ # Alpha cannot be negative
+ op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
+ op.attrs["alpha"] = 0
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.attrs["alpha"] = -1
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_hardswish_dtype():
+ # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
+ # UINT8
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # INT8
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ # Invalid
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+ in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
+ out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_keep_dims_ifm_ofm():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
+ op.attrs["keep_num_dims"] = True
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.attrs["keep_num_dims"] = False
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def create_mean(input_shape, output_shape, axis, datatype, attrs):
+ ifm = Tensor(input_shape, datatype, "in")
+ ifm.quantization = testutil.default_quant_params()
+ ofm = Tensor(output_shape, datatype, "out")
+ ofm.quantization = testutil.default_quant_params()
+ if type(axis) is list:
+ indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+ elif type(axis) is int:
+ indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
+ op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
+ return op
+
+
+def test_mean_dtype():
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.ifm.dtype = DataType.int16
+ op.ofm.dtype = DataType.int16
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_mean_axis():
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 38308154..af5dc174 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -15,55 +15,20 @@
# limitations under the License.
#
# Description:
-# Unit tests for support_operators
+# Unit tests for tflite_supported_operators
import numpy as np
from ethosu.vela.data_type import DataType
from ethosu.vela.operation import ActivationFunction
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
-from ethosu.vela.supported_operators import SupportedOperators
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
+from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators
-support = SupportedOperators()
-
-
-def test_constraint_tens_no_dynamic():
- # Tensors cannot be dynamic (no shape, not a scalar)
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_defined_shape():
- # Tensors cannot have None in them
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_output_scalar():
- # Scalar output is not allowed at all:
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
- op.ofm.values = 0.5
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_input_scalar():
- # Shapeless input is allowed if it's of a certain type:
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
- assert support.is_operator_supported(op)
- # Invalid shapeless input due to op type:
- op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
- op.ifm.values = 0.5
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_shape_size():
- # Tensors cannot be > 4D
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
- assert not support.is_operator_supported(op)
+support = TFLiteSupportedOperators()
def test_constraint_tens_dtype():
@@ -86,21 +51,6 @@ def test_constraint_tens_dimension():
assert not support.is_operator_supported(op)
-def test_constraint_tens_quant_none_check():
- # Tensors must have quantization parameters
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_quant_scale():
- # Quantization scale cannot be infinite
- qp = QuantizationParameters()
- qp.zero_point = 0
- qp.scale_f32 = np.inf
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
- assert not support.is_operator_supported(op)
-
-
def test_constraint_tens_quant_per_axis_not_supp():
# Quantization scale cannot be array-valued for elemwise ops
qp = QuantizationParameters()
@@ -123,15 +73,6 @@ def test_constraint_tens_quant_per_axis_is_supp():
assert support.is_operator_supported(op)
-def test_constraint_fc_output_2d_not_supp():
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
- assert not support.is_operator_supported(op)
-
-
def test_constraint_fc_output_2d_is_supp():
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
assert support.is_operator_supported(op)
@@ -158,49 +99,35 @@ def test_constraint_faf_ofm_dtype():
def test_constraint_conv_pass():
# First test a simple conv passes
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert support.is_operator_supported(op)
-def test_constraint_stride_type():
- # Stride width and height must be integer types
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 1.5, "stride_h": "1"}
- assert not support.is_operator_supported(op)
-
-
def test_constraint_stride_range():
# Stride width and height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 0, "stride_h": 20}
assert not support.is_operator_supported(op)
-def test_constraint_dilation_type():
- # Dilation width and height must be integer types
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
- assert not support.is_operator_supported(op)
-
-
def test_constraint_dilation_range():
# Dilation width and height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 0, "dilation_h_factor": 20}
assert not support.is_operator_supported(op)
def test_constraint_dilated_height_range():
# Dilated kernel height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
def test_constraint_dilated_product_range():
# Dilated kernel width x height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
@@ -208,7 +135,7 @@ def test_constraint_dilated_product_range():
def test_constraint_weights_type():
# Weight tensor must be 8-bit
op = testutil.create_op_with_quant_tensors(
- Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16
+ Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16
)
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
@@ -216,7 +143,7 @@ def test_constraint_weights_type():
def test_constraint_weights_const():
# Weight tensor cannot be non-const tensors
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 1, "stride_h": 1}
weights = Tensor([64, 64, 1, 1], DataType.uint8, "weights")
weights.quantization = testutil.default_quant_params()
@@ -226,7 +153,7 @@ def test_constraint_weights_const():
def test_constraint_weights_limit():
# Sum of weights has a limit
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
op.weights.quantization.zero_point = np.array([[[[(127 * 65536) + 1]]]])
assert not support.is_operator_supported(op)
@@ -252,28 +179,11 @@ def test_constraint_bias_40bit():
def test_constraint_batch_size():
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
-def test_constraint_quant_scale_inf():
- # Test handling when IFM scale / OFM scale is infinite
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
- op.ifm.quantization.scale_f32 = np.float32(1e9)
- op.ofm.quantization.scale_f32 = np.float32(1e-35)
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_ofm_scale_too_small():
- # Tests handling of OFM scale < 1e-38
- shp = [1, 10, 20, 16]
- op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),)
- assert support.is_operator_supported(op)
- op.ofm.quantization.scale_f32 = 1e-43
- assert not support.is_operator_supported(op)
-
-
def test_constraint_depth_multiplier():
# Valid. Depth multiplier is 1 so no further constraints
op = testutil.create_op_with_quant_tensors(
@@ -339,23 +249,6 @@ def test_constraint_tconv_valid():
assert not support.is_operator_supported(op)
-def test_constraint_matching_in_out_types():
- # Valid
- op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
- assert support.is_operator_supported(op)
- # Invalid: datatypes for ifm and ofm must match (default is uint8)
- op.ifm.dtype = DataType.int8
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_filter_type():
- # Filter width/height must be integers
- op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
- assert not support.is_operator_supported(op)
-
-
def test_constraint_filter_range():
# Avg pool restrictions are dependent on padding:
# SAME padding restricts both W and H to max 8
@@ -434,36 +327,6 @@ def test_constraint_resize():
assert not support.is_operator_supported(op)
-def test_constraint_matching_shapes():
- # Softmax requires the ifm and ofm shapes to match
- op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
- assert support.is_operator_supported(op)
-
-
-def test_constraint_beta_value_range():
- # beta must be positive
- op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
- op.attrs["beta"] = -1.0
- assert not support.is_operator_supported(op)
- op.attrs["beta"] = 0.0
- assert support.is_operator_supported(op)
-
-
-def test_constraint_splitv_inferred():
- # SplitV requires a maximum of one inferred shape (-1)
- qp = testutil.default_quant_params()
- op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
- sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
- op.add_input_tensor(sizes)
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
- sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
- op.add_input_tensor(sizes)
- assert support.is_operator_supported(op)
-
-
def test_constraint_concat_pass():
# A working concat
op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
@@ -474,59 +337,6 @@ def test_constraint_concat_pass():
assert support.is_operator_supported(op)
-def test_constraint_axis_exists():
- # Missing axis attribute
- op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
- ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
- ifm2.quantization = testutil.default_quant_params()
- op.add_input_tensor(ifm2)
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_axis_valid():
- # Invalid axis attribute
- op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
- ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
- ifm2.quantization = testutil.default_quant_params()
- op.add_input_tensor(ifm2)
- op.attrs["axis"] = 7
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_matching_dimensionality():
- # Mismatching dimensionality: 4D+2D=4D
- op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
- ifm2 = Tensor([1, 4], DataType.uint8, "in2")
- ifm2.quantization = testutil.default_quant_params()
- op.add_input_tensor(ifm2)
- op.attrs["axis"] = 3
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_valid_dimensions():
- # Mismatching dimension value:
- # ifm2 has w and h as 2, which is not the concat axis and doesn't match ifm1 or ofm
- op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
- ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
- ifm2.quantization = testutil.default_quant_params()
- op.add_input_tensor(ifm2)
- op.attrs["axis"] = 3
- assert not support.is_operator_supported(op)
-
-
-def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
- qp = testutil.default_quant_params()
- in0 = Tensor(in_shape, DataType.uint8, "in")
- in0.quantization = qp
- in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
- in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
- in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
- out = Tensor(out_shape, DataType.uint8, "out")
- out.quantization = qp
- attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
- return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
-
-
def create_pad_op(
in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
):
@@ -540,14 +350,6 @@ def create_pad_op(
return op
-def test_constraint_pad_input_count():
- # Incorrect number of input tensors (Pad expects exactly 2)
- op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
- assert support.is_operator_supported(op)
- op.add_input_tensor(op.inputs[0].clone())
- assert not support.is_operator_supported(op)
-
-
def test_constraint_padded_dimensions():
# Incorrect padding dimensions, can only pad width and height
op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],)
@@ -582,6 +384,19 @@ def test_constraint_pad_dtype():
assert not support.is_operator_supported(op)
+def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, DataType.uint8, "in")
+ in0.quantization = qp
+ in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
+ in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
+ in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
+ out = Tensor(out_shape, DataType.uint8, "out")
+ out.quantization = qp
+ attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
+ return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
+
+
def create_strided_slice():
# Creates a valid strided slice operator with some valid inputs/outputs
op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
@@ -591,26 +406,6 @@ def create_strided_slice():
return op
-def test_constraint_stridedslice_input_count():
- # Wrong number of input tensors
- op = create_strided_slice()
- op.add_input_tensor(op.inputs[0].clone())
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_stridedslice_inputs_const():
- # begin, end, stride values must not be None
- op = create_strided_slice()
- op.inputs[1].values = None
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.inputs[2].values = None
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.inputs[3].values = None
- assert not support.is_operator_supported(op)
-
-
def test_constraint_stridedslice_stride_values():
# Unsupported strides
op = create_strided_slice()
@@ -618,70 +413,6 @@ def test_constraint_stridedslice_stride_values():
assert not support.is_operator_supported(op)
-def test_constraint_ellipsis_mask():
- # Unsupported ellipsis mask
- op = create_strided_slice()
- op.attrs["ellipsis_mask"] = 1
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_axis_masks():
- op = create_strided_slice()
- # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
- op.attrs["new_axis_mask"] = 2
- assert support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["shrink_axis_mask"] = 3
- assert support.is_operator_supported(op)
- # But setting both to non-zero is not supported
- op.attrs["new_axis_mask"] = 2
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_slice_ranges():
- # Examples where end offset <= begin offset
- op = create_strided_slice()
- op.inputs[1].values = [0, 7, 2, 0]
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.inputs[2].values = [0, 7, 2, 0]
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["begin_mask"] = 0
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["end_mask"] = 0
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_matching_inputs_types():
- # input data types must match (default is uint8)
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
- op.ifm2.dtype = DataType.int8
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_matching_signed():
- # signed inputs require output to also be signed
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
- op.ofm.dtype = DataType.uint8
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_unsigned_valid():
- # unsigned inputs require output to be either:
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
- # the same (default uint8)
- assert support.is_operator_supported(op)
- op.ofm.dtype = DataType.int8
- assert not support.is_operator_supported(op)
- op.ofm.dtype = DataType.int16
- assert not support.is_operator_supported(op)
- # or int32
- op.ofm.dtype = DataType.int32
- assert support.is_operator_supported(op)
-
-
def test_constraint_inputs_int32():
# both inputs must be type int32
op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
@@ -750,28 +481,6 @@ def test_constraint_elemwise_batch_size():
assert not support.is_operator_supported(op)
-def test_constraint_matching_either_shapes():
- # BINARY CASE
- # At least one ifm shape must match ofm's shape
- op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
- assert support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
- assert support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
- assert not support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
- assert not support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
- assert not support.is_operator_supported(op)
-
- # UNARY CASE
- # No second input so this is treated the same as requiring ifm shape to match ofm shape
- op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
- assert support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
- assert not support.is_operator_supported(op)
-
-
def test_constraint_broadcast_shapes():
# BINARY CASE
# Only allow broadcast to 1 dim, for 1 rank index
@@ -800,46 +509,6 @@ def test_constraint_broadcast_shapes():
assert not support.is_operator_supported(op)
-def test_constraint_alpha_valid():
- # Alpha cannot be negative
- op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
- op.attrs["alpha"] = 0
- assert support.is_operator_supported(op)
- op.attrs["alpha"] = -1
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_hardswish_dtype():
- # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
- # UINT8
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
- assert support.is_operator_supported(op)
- # INT8
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
- assert support.is_operator_supported(op)
-
- # Invalid
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
- assert not support.is_operator_supported(op)
-
- in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
- out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
- op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_keep_dims_ifm_ofm():
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
- op.attrs["keep_num_dims"] = True
- assert not support.is_operator_supported(op)
- op.attrs["keep_num_dims"] = False
- assert support.is_operator_supported(op)
-
-
def create_mean(input_shape, output_shape, axis, datatype, attrs):
ifm = Tensor(input_shape, datatype, "in")
ifm.quantization = testutil.default_quant_params()
@@ -853,33 +522,6 @@ def create_mean(input_shape, output_shape, axis, datatype, attrs):
return op
-def test_mean_dtype():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
- op.ifm.dtype = DataType.int16
- op.ofm.dtype = DataType.int16
- assert not support.is_operator_supported(op)
-
-
-def test_mean_axis():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
- assert not support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
- assert not support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
- assert not support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
- assert not support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
-
-
def test_mean_hw_product():
op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {})
assert support.is_operator_supported(op)
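Both suites run under pytest; from a checkout of ethos-u-vela, an invocation along the lines of the following (file paths as in the diffstat above) exercises them:

    pytest ethosu/vela/test/test_tflite_model_semantic.py ethosu/vela/test/test_tflite_supported_operators.py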