author     Jonas Ohlsson <jonas.ohlsson@arm.com>  2021-07-26 16:13:12 +0200
committer  Jonas Ohlsson <jonas.ohlsson@arm.com>  2021-07-27 11:06:27 +0200
commit     45e653dbd81633b8d78215b16a9b2205e39dd8e2 (patch)
tree       18b3073eac45e9e8d69a616ae96d7a3fbdef9663 /ethosu
parent     c2449827ec55f49b6087e3e385fb3c4f6776dc6a (diff)
download   ethos-u-vela-45e653dbd81633b8d78215b16a9b2205e39dd8e2.tar.gz
MLBEDSW-4853: Refactor supported operators
Refactor supported operators by breaking out model semantics into its own class.
Model semantics are now checked right after the model is read.

Signed-off-by: Jonas Ohlsson <jonas.ohlsson@arm.com>
Change-Id: If442b189efcd91dda01af60b2b3adedfacdf2fad
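The refactor separates two questions that were previously handled by the single SupportedOperators class: whether an operator is semantically valid TensorFlow Lite (TFLiteSemantic, checked right after the model is read) and whether it can be placed on the NPU (TFLiteSupportedOperators, queried later by the graph optimiser). A minimal sketch of the two entry points, reusing the test helpers that appear in this patch:

    from ethosu.vela.operation import Op
    from ethosu.vela.test import testutil
    from ethosu.vela.tflite_model_semantic import TFLiteSemantic
    from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators

    semantic_checker = TFLiteSemantic()
    support = TFLiteSupportedOperators()

    # A trivial 1x1 convolution, built with the same helper the unit tests use.
    op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1}

    # Model-semantic check, now run directly after model read ...
    assert semantic_checker.is_operator_semantic_valid(op)
    # ... and the NPU support check used later to set op.run_on_npu.
    assert support.is_operator_supported(op)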
Diffstat (limited to 'ethosu')
-rw-r--r--  ethosu/vela/architecture_features.py | 4
-rw-r--r--  ethosu/vela/model_reader.py | 41
-rw-r--r--  ethosu/vela/test/test_tflite_model_semantic.py | 460
-rw-r--r--  ethosu/vela/test/test_tflite_supported_operators.py (renamed from ethosu/vela/test/test_supported_operators.py) | 408
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py | 2
-rw-r--r--  ethosu/vela/tflite_model_semantic.py | 527
-rw-r--r--  ethosu/vela/tflite_supported_operators.py (renamed from ethosu/vela/supported_operators.py) | 553
-rw-r--r--  ethosu/vela/tosa_model_semantic.py | 57
-rw-r--r--  ethosu/vela/vela.py | 117
9 files changed, 1245 insertions, 924 deletions
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 98d3d8c2..aaf1ae45 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -32,12 +32,12 @@ from .numeric_util import round_up_to_int
from .operation import Kernel
from .operation import NpuBlockType
from .operation import PointXYZ
-from .supported_operators import SupportedOperators
from .tensor import BandwidthDirection
from .tensor import MemArea
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
+from .tflite_supported_operators import TFLiteSupportedOperators
from .tosa_supported_operators import TosaSupportedOperators
@@ -398,7 +398,7 @@ class ArchitectureFeatures:
self.generate_block_config_map(Block(ifm_block_max.width * 2, ifm_block_max.height, 128))
# Setup supported operators and restriction checkers class
- self.supported_operators = SupportedOperators()
+ self.tflite_supported_operators = TFLiteSupportedOperators()
self.tosa_supported_operators = TosaSupportedOperators()
# Returns available number of SHRAM banks depending on activation lookup table
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py
index f48645d3..3b094361 100644
--- a/ethosu/vela/model_reader.py
+++ b/ethosu/vela/model_reader.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -15,7 +15,9 @@
# limitations under the License.
# Description:
# Dispatcher for reading a neural network model.
+from . import tflite_model_semantic
from . import tflite_reader
+from . import tosa_model_semantic
from . import tosa_reader
from .errors import InputFileError
from .nn_graph import NetworkType
@@ -39,16 +41,17 @@ def read_model(fname, options, feed_dict=None, output_node_names=None, initialis
output_node_names = []
if initialisation_nodes is None:
initialisation_nodes = []
- return (
- tflite_reader.read_tflite(
- fname,
- options.batch_size,
- feed_dict=feed_dict,
- output_node_names=output_node_names,
- initialisation_nodes=initialisation_nodes,
- ),
- NetworkType.TFLite,
+
+ nng = tflite_reader.read_tflite(
+ fname,
+ options.batch_size,
+ feed_dict=feed_dict,
+ output_node_names=output_node_names,
+ initialisation_nodes=initialisation_nodes,
)
+ nng = tflite_model_semantic.tflite_semantic_checker(nng)
+
+ return (nng, NetworkType.TFLite)
elif fname.endswith(".tosa"):
if feed_dict is None:
feed_dict = {}
@@ -57,15 +60,15 @@ def read_model(fname, options, feed_dict=None, output_node_names=None, initialis
if initialisation_nodes is None:
initialisation_nodes = []
- return (
- tosa_reader.read_tosa(
- fname,
- options.batch_size,
- feed_dict=feed_dict,
- output_node_names=output_node_names,
- initialisation_nodes=initialisation_nodes,
- ),
- NetworkType.TOSA,
+ nng = tosa_reader.read_tosa(
+ fname,
+ options.batch_size,
+ feed_dict=feed_dict,
+ output_node_names=output_node_names,
+ initialisation_nodes=initialisation_nodes,
)
+ nng = tosa_model_semantic.tosa_semantic_checker(nng)
+
+ return (nng, NetworkType.TOSA)
else:
raise InputFileError(fname, "Unsupported file extension. Only .tflite and .tosa files are supported")
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
new file mode 100644
index 00000000..4c329844
--- /dev/null
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -0,0 +1,460 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Unit tests for tflite_model_semantic
+import numpy as np
+
+from ethosu.vela.data_type import DataType
+from ethosu.vela.operation import Op
+from ethosu.vela.operation import Padding
+from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import QuantizationParameters
+from ethosu.vela.tensor import Tensor
+from ethosu.vela.test import testutil
+from ethosu.vela.tflite_model_semantic import TFLiteSemantic
+
+semantic_checker = TFLiteSemantic()
+
+
+def test_constraint_tens_no_dynamic():
+ # Tensors cannot be dynamic (no shape, not a scalar)
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_defined_shape():
+ # Tensors cannot have None in them
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_output_scalar():
+ # Scalar output is not allowed at all:
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
+ op.ofm.values = 0.5
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_input_scalar():
+ # Shapeless input is allowed if it's of a certain type:
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # Invalid shapeless input due to op type:
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
+ op.ifm.values = 0.5
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_shape_size():
+ # Tensors cannot be > 4D
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_quant_none_check():
+ # Tensors must have quantization parameters
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_tens_quant_scale():
+ # Quantization scale cannot be infinite
+ qp = QuantizationParameters()
+ qp.zero_point = 0
+ qp.scale_f32 = np.inf
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_fc_output_2d_not_supp():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_fc_output_2d_is_supp():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_conv_pass():
+ # First test a simple conv passes
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 1, "stride_h": 1}
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_stride_type():
+ # Stride width and height must be integer types
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 1.5, "stride_h": "1"}
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_dilation_type():
+ # Dilation width and height must be integer types
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_quant_scale_inf():
+ # Test handling IFM scale/OFM scale is infinite
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.ifm.quantization.scale_f32 = np.float32(1e9)
+ op.ofm.quantization.scale_f32 = np.float32(1e-35)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_ofm_scale_too_small():
+ # Tests handling of OFM scale < 1e-38
+ shp = [1, 10, 20, 16]
+ op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.ofm.quantization.scale_f32 = 1e-43
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_in_out_types():
+ # Valid
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # Invalid. datatypes for ifm and ofm must match (default uint8)
+ op.ifm.dtype = DataType.int8
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_filter_type():
+ # Filter width/height must be integers
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_shapes():
+ # Softmax requires the ifm and ofm shapes to match
+ op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_beta_value_range():
+ # beta must be positive
+ op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
+ op.attrs["beta"] = -1.0
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.attrs["beta"] = 0.0
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_splitv_inferred():
+ # SplitV requires a maximum of one inferred shape (-1)
+ qp = testutil.default_quant_params()
+ op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
+ sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
+ op.add_input_tensor(sizes)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
+ sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
+ op.add_input_tensor(sizes)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_concat_pass():
+ # A working concat
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_axis_exists():
+ # Missing axis attribute
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_axis_valid():
+ # Invalid axis attribute
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 7
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_dimensionality():
+ # Mismatching dimensionality: 4D+2D=4D
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_valid_dimensions():
+ # Mismatching dimension value:
+ # ifm2 has w and h as 2, which is not the axis to concat and doesn't match ifm1 or ofm
+ op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, DataType.uint8, "in")
+ in0.quantization = qp
+ in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
+ in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
+ in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
+ out = Tensor(out_shape, DataType.uint8, "out")
+ out.quantization = qp
+ attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
+ return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
+
+
+def create_pad_op(
+ in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
+):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, in_dtype, "in")
+ in0.quantization = qp
+ pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+ out = Tensor(out_shape, out_dtype, "out")
+ out.quantization = qp.clone()
+ op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+ return op
+
+
+def test_constraint_pad_input_count():
+ # Incorrect number of input tensors (2)
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.add_input_tensor(op.inputs[0].clone())
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def create_strided_slice():
+ # Creates a valid strided slice operator with some valid inputs/outputs
+ op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
+ op.attrs["begin_mask"] = 1
+ op.attrs["end_mask"] = 9
+ assert semantic_checker.is_operator_semantic_valid(op)
+ return op
+
+
+def test_constraint_stridedslice_input_count():
+ # Wrong number of input tensors
+ op = create_strided_slice()
+ op.add_input_tensor(op.inputs[0].clone())
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_stridedslice_inputs_const():
+ # begin, end, stride values must not be None
+ op = create_strided_slice()
+ op.inputs[1].values = None
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.inputs[2].values = None
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.inputs[3].values = None
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_ellipsis_mask():
+ # Unsupported ellipsis mask
+ op = create_strided_slice()
+ op.attrs["ellipsis_mask"] = 1
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_axis_masks():
+ op = create_strided_slice()
+ # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
+ op.attrs["new_axis_mask"] = 2
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.attrs["shrink_axis_mask"] = 3
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # But setting both to non-zero is not supported
+ op.attrs["new_axis_mask"] = 2
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_slice_ranges():
+ # Examples where end offset <= begin offset
+ op = create_strided_slice()
+ op.inputs[1].values = [0, 7, 2, 0]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.inputs[2].values = [0, 7, 2, 0]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.attrs["begin_mask"] = 0
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_strided_slice()
+ op.attrs["end_mask"] = 0
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_inputs_types():
+ # input data types must match (default is uint8)
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ op.ifm2.dtype = DataType.int8
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_signed():
+ # signed inputs require output to also be signed
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+ op.ofm.dtype = DataType.uint8
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_unsigned_valid():
+ # unsigned inputs require output to be either:
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ # the same (default uint8)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.ofm.dtype = DataType.int8
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.ofm.dtype = DataType.int16
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # or int32
+ op.ofm.dtype = DataType.int32
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_matching_either_shapes():
+ # BINARY CASE
+ # At least one ifm shape must match ofm's shape
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+ # UNARY CASE
+ # No second input so this is treated the same as requiring ifm shape to match ofm shape
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_alpha_valid():
+ # Alpha cannot be negative
+ op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
+ op.attrs["alpha"] = 0
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.attrs["alpha"] = -1
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_hardswish_dtype():
+ # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
+ # UINT8
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # INT8
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ # Invalid
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+ in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
+ out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_keep_dims_ifm_ofm():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
+ op.attrs["keep_num_dims"] = True
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.attrs["keep_num_dims"] = False
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def create_mean(input_shape, output_shape, axis, datatype, attrs):
+ ifm = Tensor(input_shape, datatype, "in")
+ ifm.quantization = testutil.default_quant_params()
+ ofm = Tensor(output_shape, datatype, "out")
+ ofm.quantization = testutil.default_quant_params()
+ if type(axis) is list:
+ indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+ elif type(axis) is int:
+ indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
+ op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
+ return op
+
+
+def test_mean_dtype():
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.ifm.dtype = DataType.int16
+ op.ofm.dtype = DataType.int16
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_mean_axis():
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 38308154..af5dc174 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -15,55 +15,20 @@
# limitations under the License.
#
# Description:
-# Unit tests for support_operators
+# Unit tests for tflite_supported_operators
import numpy as np
from ethosu.vela.data_type import DataType
from ethosu.vela.operation import ActivationFunction
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
-from ethosu.vela.supported_operators import SupportedOperators
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
+from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators
-support = SupportedOperators()
-
-
-def test_constraint_tens_no_dynamic():
- # Tensors cannot be dynamic (no shape, not a scalar)
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_defined_shape():
- # Tensors cannot have None in them
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_output_scalar():
- # Scalar output is not allowed at all:
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
- op.ofm.values = 0.5
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_input_scalar():
- # Shapeless input is allowed if its of a certain type:
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
- assert support.is_operator_supported(op)
- # Invalid shapeless input due to op type:
- op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
- op.ifm.values = 0.5
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_shape_size():
- # Tensors cannot be > 4D
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
- assert not support.is_operator_supported(op)
+support = TFLiteSupportedOperators()
def test_constraint_tens_dtype():
@@ -86,21 +51,6 @@ def test_constraint_tens_dimension():
assert not support.is_operator_supported(op)
-def test_constraint_tens_quant_none_check():
- # Tensors must have quantization parameters
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_tens_quant_scale():
- # Quantization scale cannot be infinite
- qp = QuantizationParameters()
- qp.zero_point = 0
- qp.scale_f32 = np.inf
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
- assert not support.is_operator_supported(op)
-
-
def test_constraint_tens_quant_per_axis_not_supp():
# Quantization scale cannot be array-valued for elemwise ops
qp = QuantizationParameters()
@@ -123,15 +73,6 @@ def test_constraint_tens_quant_per_axis_is_supp():
assert support.is_operator_supported(op)
-def test_constraint_fc_output_2d_not_supp():
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
- assert not support.is_operator_supported(op)
-
-
def test_constraint_fc_output_2d_is_supp():
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
assert support.is_operator_supported(op)
@@ -158,49 +99,35 @@ def test_constraint_faf_ofm_dtype():
def test_constraint_conv_pass():
# First test a simple conv passes
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert support.is_operator_supported(op)
-def test_constraint_stride_type():
- # Stride width and height must be integer types
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 1.5, "stride_h": "1"}
- assert not support.is_operator_supported(op)
-
-
def test_constraint_stride_range():
# Stride width and height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 0, "stride_h": 20}
assert not support.is_operator_supported(op)
-def test_constraint_dilation_type():
- # Dilation width and height must be integer types
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
- assert not support.is_operator_supported(op)
-
-
def test_constraint_dilation_range():
# Dilation width and height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 0, "dilation_h_factor": 20}
assert not support.is_operator_supported(op)
def test_constraint_dilated_height_range():
# Dilated kernel height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
def test_constraint_dilated_product_range():
# Dilated kernel width x height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
@@ -208,7 +135,7 @@ def test_constraint_dilated_product_range():
def test_constraint_weights_type():
# Weight tensor must be 8-bit
op = testutil.create_op_with_quant_tensors(
- Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16
+ Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16
)
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
@@ -216,7 +143,7 @@ def test_constraint_weights_type():
def test_constraint_weights_const():
# Weight tensor cannot be non-const tensors
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 1, "stride_h": 1}
weights = Tensor([64, 64, 1, 1], DataType.uint8, "weights")
weights.quantization = testutil.default_quant_params()
@@ -226,7 +153,7 @@ def test_constraint_weights_const():
def test_constraint_weights_limit():
# Sum of weights has a limit
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
op.weights.quantization.zero_point = np.array([[[[(127 * 65536) + 1]]]])
assert not support.is_operator_supported(op)
@@ -252,28 +179,11 @@ def test_constraint_bias_40bit():
def test_constraint_batch_size():
- op = testutil.create_op_with_quant_tensors(Op.Conv2D, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
-def test_constraint_quant_scale_inf():
- # Test handling IFM scale/OFM scale is infinite
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
- op.ifm.quantization.scale_f32 = np.float32(1e9)
- op.ofm.quantization.scale_f32 = np.float32(1e-35)
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_ofm_scale_too_small():
- # Tests handling of OFM scale < 1e-38
- shp = [1, 10, 20, 16]
- op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),)
- assert support.is_operator_supported(op)
- op.ofm.quantization.scale_f32 = 1e-43
- assert not support.is_operator_supported(op)
-
-
def test_constraint_depth_multiplier():
# Valid. Depth multiplier is 1 so no further constraints
op = testutil.create_op_with_quant_tensors(
@@ -339,23 +249,6 @@ def test_constraint_tconv_valid():
assert not support.is_operator_supported(op)
-def test_constraint_matching_in_out_types():
- # Valid
- op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
- assert support.is_operator_supported(op)
- # Invalid. datatypes for ifm and ofm must match (default uint8)
- op.ifm.dtype = DataType.int8
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_filter_type():
- # Filter width/height must be integers
- op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
- assert not support.is_operator_supported(op)
-
-
def test_constraint_filter_range():
# Avg pool restrictions are dependent on padding:
# SAME padding restricts both W and H to max 8
@@ -434,36 +327,6 @@ def test_constraint_resize():
assert not support.is_operator_supported(op)
-def test_constraint_matching_shapes():
- # Softmax requires the ifm and ofm shapes to match
- op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
- assert support.is_operator_supported(op)
-
-
-def test_constraint_beta_value_range():
- # beta must be positive
- op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
- op.attrs["beta"] = -1.0
- assert not support.is_operator_supported(op)
- op.attrs["beta"] = 0.0
- assert support.is_operator_supported(op)
-
-
-def test_constraint_splitv_inferred():
- # SplitV requires a maximum of one inferred shape (-1)
- qp = testutil.default_quant_params()
- op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
- sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
- op.add_input_tensor(sizes)
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
- sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
- op.add_input_tensor(sizes)
- assert support.is_operator_supported(op)
-
-
def test_constraint_concat_pass():
# A working concat
op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
@@ -474,59 +337,6 @@ def test_constraint_concat_pass():
assert support.is_operator_supported(op)
-def test_constraint_axis_exists():
- # Missing axis attribute
- op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
- ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
- ifm2.quantization = testutil.default_quant_params()
- op.add_input_tensor(ifm2)
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_axis_valid():
- # Invalid axis attribute
- op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
- ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
- ifm2.quantization = testutil.default_quant_params()
- op.add_input_tensor(ifm2)
- op.attrs["axis"] = 7
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_matching_dimensionality():
- # Mismatching dimensionality: 4D+2D=4D
- op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
- ifm2 = Tensor([1, 4], DataType.uint8, "in2")
- ifm2.quantization = testutil.default_quant_params()
- op.add_input_tensor(ifm2)
- op.attrs["axis"] = 3
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_valid_dimensions():
- # Mismatching dimension value:
- # ifm2 has w and h as 2, which is not the axis to concat and doesnt match ifm1 or ofm
- op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
- ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
- ifm2.quantization = testutil.default_quant_params()
- op.add_input_tensor(ifm2)
- op.attrs["axis"] = 3
- assert not support.is_operator_supported(op)
-
-
-def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
- qp = testutil.default_quant_params()
- in0 = Tensor(in_shape, DataType.uint8, "in")
- in0.quantization = qp
- in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
- in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
- in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
- out = Tensor(out_shape, DataType.uint8, "out")
- out.quantization = qp
- attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
- return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
-
-
def create_pad_op(
in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
):
@@ -540,14 +350,6 @@ def create_pad_op(
return op
-def test_constraint_pad_input_count():
- # Incorrect number of input tensors (2)
- op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
- assert support.is_operator_supported(op)
- op.add_input_tensor(op.inputs[0].clone())
- assert not support.is_operator_supported(op)
-
-
def test_constraint_padded_dimensions():
# Incorrect padding dimensions, can only pad width and height
op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],)
@@ -582,6 +384,19 @@ def test_constraint_pad_dtype():
assert not support.is_operator_supported(op)
+def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, DataType.uint8, "in")
+ in0.quantization = qp
+ in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
+ in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
+ in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
+ out = Tensor(out_shape, DataType.uint8, "out")
+ out.quantization = qp
+ attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
+ return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
+
+
def create_strided_slice():
# Creates a valid strided slice operator with some valid inputs/outputs
op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
@@ -591,26 +406,6 @@ def create_strided_slice():
return op
-def test_constraint_stridedslice_input_count():
- # Wrong number of input tensors
- op = create_strided_slice()
- op.add_input_tensor(op.inputs[0].clone())
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_stridedslice_inputs_const():
- # begin, end, stride values must not be None
- op = create_strided_slice()
- op.inputs[1].values = None
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.inputs[2].values = None
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.inputs[3].values = None
- assert not support.is_operator_supported(op)
-
-
def test_constraint_stridedslice_stride_values():
# Unsupported strides
op = create_strided_slice()
@@ -618,70 +413,6 @@ def test_constraint_stridedslice_stride_values():
assert not support.is_operator_supported(op)
-def test_constraint_ellipsis_mask():
- # Unsupported ellipsis mask
- op = create_strided_slice()
- op.attrs["ellipsis_mask"] = 1
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_axis_masks():
- op = create_strided_slice()
- # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
- op.attrs["new_axis_mask"] = 2
- assert support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["shrink_axis_mask"] = 3
- assert support.is_operator_supported(op)
- # But setting both to non-zero is not supported
- op.attrs["new_axis_mask"] = 2
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_slice_ranges():
- # Examples where end offset <= begin offset
- op = create_strided_slice()
- op.inputs[1].values = [0, 7, 2, 0]
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.inputs[2].values = [0, 7, 2, 0]
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["begin_mask"] = 0
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["end_mask"] = 0
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_matching_inputs_types():
- # input data types must match (default is uint8)
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
- op.ifm2.dtype = DataType.int8
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_matching_signed():
- # signed inputs require output to also be signed
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
- op.ofm.dtype = DataType.uint8
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_unsigned_valid():
- # unsigned inputs require output to be either:
- op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
- # the same (default uint8)
- assert support.is_operator_supported(op)
- op.ofm.dtype = DataType.int8
- assert not support.is_operator_supported(op)
- op.ofm.dtype = DataType.int16
- assert not support.is_operator_supported(op)
- # or int32
- op.ofm.dtype = DataType.int32
- assert support.is_operator_supported(op)
-
-
def test_constraint_inputs_int32():
# both inputs must be type int32
op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
@@ -750,28 +481,6 @@ def test_constraint_elemwise_batch_size():
assert not support.is_operator_supported(op)
-def test_constraint_matching_either_shapes():
- # BINARY CASE
- # At least one ifm shape must match ofm's shape
- op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
- assert support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
- assert support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
- assert not support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
- assert not support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
- assert not support.is_operator_supported(op)
-
- # UNARY CASE
- # No second input so this is treated the same as requiring ifm shape to match ofm shape
- op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
- assert support.is_operator_supported(op)
- op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
- assert not support.is_operator_supported(op)
-
-
def test_constraint_broadcast_shapes():
# BINARY CASE
# Only allow broadcast to 1 dim, for 1 rank index
@@ -800,46 +509,6 @@ def test_constraint_broadcast_shapes():
assert not support.is_operator_supported(op)
-def test_constraint_alpha_valid():
- # Alpha cannot be negative
- op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
- op.attrs["alpha"] = 0
- assert support.is_operator_supported(op)
- op.attrs["alpha"] = -1
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_hardswish_dtype():
- # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
- # UINT8
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
- assert support.is_operator_supported(op)
- # INT8
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
- assert support.is_operator_supported(op)
-
- # Invalid
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
- assert not support.is_operator_supported(op)
- op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
- assert not support.is_operator_supported(op)
-
- in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
- out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
- op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_keep_dims_ifm_ofm():
- op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
- op.attrs["keep_num_dims"] = True
- assert not support.is_operator_supported(op)
- op.attrs["keep_num_dims"] = False
- assert support.is_operator_supported(op)
-
-
def create_mean(input_shape, output_shape, axis, datatype, attrs):
ifm = Tensor(input_shape, datatype, "in")
ifm.quantization = testutil.default_quant_params()
@@ -853,33 +522,6 @@ def create_mean(input_shape, output_shape, axis, datatype, attrs):
return op
-def test_mean_dtype():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
- op.ifm.dtype = DataType.int16
- op.ofm.dtype = DataType.int16
- assert not support.is_operator_supported(op)
-
-
-def test_mean_axis():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
- assert not support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
- assert not support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
- assert not support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
- assert not support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
- assert support.is_operator_supported(op)
-
-
def test_mean_hw_product():
op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {})
assert support.is_operator_supported(op)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 9fdff8ff..68b4e8ea 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1530,7 +1530,7 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
def supported_operator_check(op, arch, nng):
- op.run_on_npu = arch.supported_operators.is_operator_supported(op)
+ op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
return op
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
new file mode 100644
index 00000000..c8b373a3
--- /dev/null
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -0,0 +1,527 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# The TFLiteSemantic class, which is a collection of TensorFlow Lite model semantic checks.
+from collections import defaultdict
+
+import numpy as np
+
+from .data_type import BaseType
+from .data_type import DataType
+from .numeric_util import is_integer
+from .operation import get_slice_offsets
+from .operation import Op
+from .supported_operators_util import docstring_format_args
+from .supported_operators_util import list_formatter
+from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
+from .tflite_mapping import optype_to_builtintype
+
+
+def _optype_formatter(op_list):
+ # Convert internal op types to external names
+ output = map(optype_to_builtintype, op_list)
+ # Remove UNKNOWNs
+ output = (x for x in output if x is not BUILTIN_OPERATOR_UNKNOWN)
+ return list_formatter(output)
+
+
+class TFLiteSemantic:
+ # Categorised lists of operators
+ convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
+ depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
+ transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
+ convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
+ max_pooling_ops = Op.op_set(Op.is_maxpool_op)
+ avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
+ pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
+ unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
+ binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
+ binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
+ binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
+ binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
+ elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
+ shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean))
+
+ def __init__(self):
+ # Setup the generic constraints. Note: the order matters
+ self.generic_constraints = []
+ self.generic_constraints.append(TFLiteSemantic.constraint_tens_no_dynamic)
+ self.generic_constraints.append(TFLiteSemantic.constraint_tens_defined_shape)
+ self.generic_constraints.append(TFLiteSemantic.constraint_tens_output_scalar)
+ self.generic_constraints.append(TFLiteSemantic.constraint_tens_input_scalar)
+ self.generic_constraints.append(TFLiteSemantic.constraint_tens_shape_size)
+
+ self.generic_constraints.append(TFLiteSemantic.constraint_tens_quant_none_check)
+ self.generic_constraints.append(TFLiteSemantic.constraint_tens_quant_scale)
+ self.generic_constraints.append(TFLiteSemantic.constraint_quant_scale_inf)
+
+ # Setup specific constraints. Note: the order matters
+ self.specific_constraints = defaultdict(list)
+
+ # Conv-like checks:
+ for op_type in TFLiteSemantic.convolution_like_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_dilation_type)
+
+ # Pooling checks:
+ for op_type in TFLiteSemantic.pooling_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
+ # AVG pooling specific checks:
+ for op_type in TFLiteSemantic.avg_pooling_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_filter_type)
+ # MAX pooling specific checks:
+ for op_type in TFLiteSemantic.max_pooling_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_filter_type)
+
+ # Concat specific checks:
+ for op_type in (Op.Concat, Op.ConcatTFLite):
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_exists)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_valid)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_dimensionality)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions)
+
+ # Element-wise checks:
+ for op_type in TFLiteSemantic.elem_wise_main_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_either_shapes)
+ # Unary specific checks:
+ for op_type in TFLiteSemantic.unary_elem_wise_main_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
+ # Binary Min/Max specific checks:
+ for op_type in TFLiteSemantic.binary_elem_wise_min_max_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
+ # Binary Add/Mul/Sub specific checks:
+ for op_type in TFLiteSemantic.binary_elem_wise_add_mul_sub:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_inputs_types)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_signed)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_unsigned_valid)
+
+ # Softmax specific checks:
+ self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_shapes)
+ self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types)
+ self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_beta_value_range)
+
+ # SplitV specific checks:
+ self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred)
+
+ # StridedSlice specific checks:
+ self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_stridedslice_input_count)
+ self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_stridedslice_inputs_const)
+ self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_ellipsis_mask)
+ self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_axis_masks)
+ self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_slice_ranges)
+
+ # LeakyRelu specific checks:
+ self.specific_constraints[Op.LeakyRelu].append(TFLiteSemantic.constraint_alpha_valid)
+
+ # FullyConnected specific checks:
+ self.specific_constraints[Op.FullyConnected].append(TFLiteSemantic.constraint_fc_output_2d)
+ self.specific_constraints[Op.FullyConnected].append(TFLiteSemantic.constraint_keep_dim_ifm_ofm)
+
+ # Pad specific checks:
+ self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_input_count)
+ self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_constant)
+
+ # HardSwish specific checks:
+ self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_input_8bit)
+ self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_matching_in_out_types)
+ # Mean specific checks:
+ self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_input_8bit)
+ self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_input_dims)
+ self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_axis)
+
+ def is_operator_semantic_valid(self, op):
+ ext_type = optype_to_builtintype(op.type)
+
+ if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
+ return True
+
+ for constraint in self.generic_constraints + self.specific_constraints[op.type]:
+ valid, extra = constraint(op)
+ if not valid:
+ print(
+ f"Warning: unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead"
+ )
+ print(f" - {constraint.__doc__}")
+ if extra:
+ print(f" {extra}")
+ return False
+
+ return True
+
+ @staticmethod
+ def constraint_tens_no_dynamic(op):
+ "Input(s) and Output tensors must not be dynamic"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.inputs + op.outputs if tens]
+ for tens in tensors:
+ if (tens.shape == []) and (tens.values is None):
+ valid = False
+ extra.append(tens.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has dynamic tensor(s): {extra}"
+
+ @staticmethod
+ def constraint_tens_defined_shape(op):
+ "Input(s) and Output tensors must have a defined shape"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.inputs + op.outputs if tens]
+ for tens in tensors:
+ if not tens.has_fully_defined_shape():
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+ return valid, ", ".join(extra)
+
+ @staticmethod
+ def constraint_tens_output_scalar(op):
+ "Output tensors cannot be scalar"
+ ofm = op.ofm
+ valid = ofm.shape != []
+ return valid, f"Output Tensor '{ofm.name}' is scalar"
+
+ @classmethod
+ @docstring_format_args([_optype_formatter(shapeless_input_ops)])
+ def constraint_tens_input_scalar(cls, op):
+ "Scalar Input tensors are only valid for op type: {}"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
+ valid = False
+ extra.append(tens.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has scalar input tensor(s): {extra}"
+
+ @staticmethod
+ def constraint_tens_shape_size(op):
+ "Input(s) and Output tensors must not be greater than 4D"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.inputs + op.outputs if tens]
+ for tens in tensors:
+ if len(tens.shape) > 4:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+ return valid, ", ".join(extra)
+
+ @staticmethod
+ def constraint_tens_quant_none_check(op):
+ "Input(s), Output and Weight tensors must have quantization parameters"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ for tens in tensors:
+ if tens.quantization is None:
+ valid = False
+ extra.append(tens.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has tensors with missing quantization parameters: {extra}"
+
+ @staticmethod
+ def constraint_tens_quant_scale(op):
+ "Input(s), Output and Weight tensors with quantization scales must be finite"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ for tens in tensors:
+ if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
+ return valid, ", ".join(extra)
+
+ @staticmethod
+ def constraint_fc_output_2d(op):
+ "The output tensor(s) must have 2D shape"
+ valid = True
+ extra = []
+ for tens in op.outputs:
+ if len(tens.shape) != 2:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' is {len(tens.shape)}D")
+ return valid, ", ".join(extra)
+
+ @staticmethod
+ def constraint_stride_type(op):
+ "Stride values for both width and height must be integer types"
+ w, h = op.get_kernel_stride()
+ valid = is_integer(w) and is_integer(h)
+ return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"
+
+ @staticmethod
+ def constraint_dilation_type(op):
+ "Dilation factor values for both width and height must be integer types"
+ w, h = op.get_kernel_dilation()
+ valid = is_integer(w) and is_integer(h)
+ return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"
+
+ @staticmethod
+ def constraint_quant_scale_inf(op):
+ "Input and Output tensors must have quantization scales that fit within float32 precision"
+ if op.ofm is not None and op.ofm.is_quantized():
+ ofm_scale = op.ofm.quantization.scale_f32
+ if ofm_scale < np.finfo(np.float32).tiny:
+ return (
+ False,
+ f"The quantization scale of the output tensor is {ofm_scale}, "
+ + f"minimum supported is: {np.finfo(np.float32).tiny}",
+ )
+ if op.ifm is not None and op.ifm.is_quantized():
+ ifm_scale = op.ifm.quantization.scale_f32
+ if np.isinf(ifm_scale / ofm_scale):
+ return (
+ False,
+ f"IFM scale divided by OFM scale is infinite, ifm_scale={ifm_scale} ofm_scale={ofm_scale}",
+ )
+ return True, "Op's quantization is ok"
+
+ @staticmethod
+ def constraint_matching_in_out_types(op):
+ "IFM and OFM data types must match"
+ ifm_dtype = op.ifm.dtype
+ ofm_dtype = op.ofm.dtype
+ valid = ifm_dtype == ofm_dtype
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
+
+ @staticmethod
+ def constraint_beta_value_range(op):
+ "Beta value needs to be positive"
+ beta = op.attrs.get("beta", 1.0)
+ valid = beta >= 0
+ return valid, f"Op has beta={beta}"
+
+ @staticmethod
+ def constraint_filter_type(op):
+ "Kernel filter values for both width and height must be integer types"
+ w = op.kernel.width
+ h = op.kernel.height
+ valid = is_integer(w) and is_integer(h)
+ return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"
+
+ @staticmethod
+ def constraint_matching_shapes(op):
+ "IFM and OFM shapes must match"
+ ifm_shape = op.ifm.shape
+ ofm_shape = op.ofm.shape
+ valid = ifm_shape == ofm_shape
+ return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"
+
+ @staticmethod
+ def constraint_splitv_inferred(op):
+ "Only one size is allowed to be inferred"
+ sizes = op.inputs[1].values
+ valid = np.count_nonzero(sizes == -1) <= 1
+ return valid, f"Op has multiple inferred sizes (-1): {sizes}"
+
+ @staticmethod
+ def constraint_axis_exists(op):
+ "Axis attribute must exist"
+ axis = op.attrs.get("axis")
+ valid = axis is not None
+ return valid, f"Op has axis={axis}"
+
+ @staticmethod
+ def constraint_axis_valid(op):
+ "Axis attribute must be in the range [0, <ofm_dimensions>)"
+ dims = len(op.ofm.shape)
+ axis = op.attrs["axis"]
+ axis += dims if axis < 0 else 0
+ valid = 0 <= axis < dims
+ return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"
+
+ @staticmethod
+ def constraint_matching_dimensionality(op):
+ "All Input dimensionalities must match OFM dimensionality"
+ valid = True
+ extra = []
+ ofm_dim = len(op.ofm.shape)
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ dim = len(tens.shape)
+ if dim != ofm_dim:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has dimension: {dim}")
+ extra = ", ".join(extra)
+ return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"
+
+ @staticmethod
+ def constraint_valid_dimensions(op):
+ "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
+ valid = True
+ extra = []
+ ofm_shape = op.ofm.shape
+ ofm_dim = len(ofm_shape)
+ axis = op.attrs["axis"]
+ axis += ofm_dim if axis < 0 else 0
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+ extra = ", ".join(extra)
+ return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
+
+ @staticmethod
+ def constraint_stridedslice_input_count(op):
+ "Exactly 4 Input tensors are required"
+ inputs = len(op.inputs)
+ valid = inputs == 4
+ return valid, f"Op has {inputs} inputs"
+
+ @staticmethod
+ def constraint_pad_input_count(op):
+ "Number of input tensors must be exactly 2"
+ inputs = len(op.inputs)
+ valid = inputs == 2
+ return valid, f"Op has {inputs} inputs"
+
+ @staticmethod
+ def constraint_pad_constant(op):
+ "The padding tensor must be constant"
+ pad_tensor = op.inputs[1].values
+ valid = pad_tensor is not None
+ return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"
+
+ @staticmethod
+ def constraint_stridedslice_inputs_const(op):
+ "Begin, End and Stride Input tensors must be constant"
+ valid = True
+ extra = []
+ _, begin, end, strides = op.inputs
+ if begin.values is None:
+ valid = False
+ extra.append(f"Begin tensor '{begin.name}'")
+ if end.values is None:
+ valid = False
+ extra.append(f"End tensor '{end.name}'")
+ if strides.values is None:
+ valid = False
+ extra.append(f"Stride tensor '{strides.name}'")
+ extra = ", ".join(extra)
+ return valid, f"Op has non-constant tensors: {extra}"
+
+ @staticmethod
+ def constraint_ellipsis_mask(op):
+ "ellipsis_mask must be 0"
+ ellipsis = op.attrs["ellipsis_mask"]
+ valid = ellipsis == 0
+ return valid, f"Op has ellipsis mask as: {ellipsis}"
+
+ @staticmethod
+ def constraint_axis_masks(op):
+ "new_axis_mask and shrink_axis_mask cannot both be set"
+ new_axis = op.attrs["new_axis_mask"]
+ shrink_axis = op.attrs["shrink_axis_mask"]
+ valid = (new_axis == 0) or (shrink_axis == 0)
+ return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"
+
+ @staticmethod
+ def constraint_slice_ranges(op):
+ "Slice 'end' values must be greater than 'begin' values"
+ ifm, begin, end, _ = op.inputs
+ # Calculate offset begin/end
+ offset_begin = get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
+ offset_end = get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
+ # Check "end - begin" doesn't result in any zero or negative elements
+ valid = all((e - b) > 0 for b, e in zip(offset_begin, offset_end))
+ return valid, f"Op has begin_values={begin.values} and end_values={end.values}"
+
+ @staticmethod
+ def constraint_matching_inputs_types(op):
+ "Both Input data types must match"
+ ifm_dtype = op.ifm.dtype
+ ifm2_dtype = op.ifm2.dtype
+ valid = ifm_dtype == ifm2_dtype
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"
+
+ @staticmethod
+ def constraint_matching_signed(op):
+ "For IFM that are signed, OFM must also be signed"
+ valid = True
+ ifm_dtype = op.ifm.dtype
+ ofm_dtype = op.ofm.dtype
+ if ifm_dtype.type & BaseType.Signed:
+ valid = bool(ofm_dtype.type & BaseType.Signed)
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
+
+ @staticmethod
+ def constraint_unsigned_valid(op):
+ "For IFM that are unsigned, OFM must either be the same type or int32"
+ valid = True
+ ifm_dtype = op.ifm.dtype
+ ofm_dtype = op.ofm.dtype
+ if ifm_dtype.type & BaseType.Unsigned:
+ valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
+
+ @staticmethod
+ def constraint_input_8bit(op):
+ "IFM must be int8 or uint8"
+ ifm_dtype = op.ifm.dtype
+ valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
+ return valid, f"Op has ifm_dtype={ifm_dtype}"
+
+ @staticmethod
+ def constraint_matching_either_shapes(op):
+ "At least one Input's shape must match the OFM's shape"
+ ifm_shape = op.ifm.shape
+ ifm2_shape = op.ifm2.shape if op.ifm2 else None
+ ofm_shape = op.ofm.shape
+ valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
+ return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"
+
+ @staticmethod
+ def constraint_alpha_valid(op):
+ "Alpha must not be negative"
+ alpha = op.attrs["alpha"]
+ valid = alpha >= 0
+ return valid, f"Op has alpha={alpha}"
+
+ @staticmethod
+ def constraint_keep_dim_ifm_ofm(op):
+ "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
+ valid = True
+ if op.attrs.get("keep_num_dims"):
+ valid = len(op.ifm.shape) == len(op.ofm.shape)
+ return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"
+
+ @staticmethod
+ def constraint_mean_input_dims(op):
+ "Input tensor must be at least 2D"
+ dims = len(op.inputs[0].shape)
+ return 2 <= dims <= 4, f"Input is {dims}D"
+
+ @staticmethod
+ def constraint_mean_axis(op):
+ "Axis indices must correspond to height and width axes"
+ dims = len(op.inputs[0].shape)
+ axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ if dims == 2 or dims == 3:
+ valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
+ elif dims == 4:
+ valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
+ return valid, f"Axis is {axis}"
+
+
+def tflite_semantic_checker(nng):
+ semantic_checker = TFLiteSemantic()
+ for sg in nng.subgraphs:
+ for op in sg.get_all_ops():
+ op.run_on_npu = semantic_checker.is_operator_semantic_valid(op)
+ return nng
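Every check introduced in this file follows the same convention: the docstring is the human-readable rule (it is also what ends up in the warnings and in SUPPORTED_OPS.md), and the function returns a (valid, extra) pair that is_operator_semantic_valid() prints when the op falls back to the CPU. Below is a minimal standalone sketch of that pattern, not part of this patch, using a made-up rule and a fake op object:

    from collections import namedtuple

    # Hypothetical stand-in for an operation; the real class lives in operation.py.
    FakeOp = namedtuple("FakeOp", ["name", "attrs"])

    def constraint_axis_not_negative(op):
        "Axis attribute must not be negative (illustrative rule only)"
        axis = op.attrs.get("axis", 0)
        valid = axis >= 0
        return valid, f"Op has axis={axis}"

    def run_checks(op, constraints):
        # Mirrors is_operator_semantic_valid(): stop at the first failing constraint.
        for constraint in constraints:
            valid, extra = constraint(op)
            if not valid:
                print(f"Warning: unsupported semantics for '{op.name}'")
                print(f" - {constraint.__doc__}")
                print(f"   {extra}")
                return False
        return True

    run_checks(FakeOp(name="concat_1", attrs={"axis": -1}), [constraint_axis_not_negative])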
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 663c78f8..cb3d5048 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -14,15 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
-# The SupportedOperators class which is a collection of all supported operators and parameter checks.
+# The TFLiteSupportedOperators class which is a collection of all TFLite supported operators and parameter checks.
from collections import defaultdict
import numpy as np
-from .data_type import BaseType
from .data_type import DataType
-from .numeric_util import is_integer
-from .operation import get_slice_offsets
from .operation import Op
from .operation import Padding
from .supported_operators_util import docstring_format_args
@@ -40,7 +37,7 @@ def _optype_formatter(op_list):
return list_formatter(output)
-class SupportedOperators:
+class TFLiteSupportedOperators:
# Categorised lists of supported operators
npu_pre_ops = set((Op.SplitSliceRead,))
convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
@@ -90,7 +87,6 @@ class SupportedOperators:
split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
memory_only_ops = set((Op.Reshape, Op.QuantizedReshape,)) | concat_ops | split_ops
- shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean))
per_axis_quant_ops = convolution_like_ops # per-axis/channel quantization only currently supported for conv ops
supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | pad_ops | npu_post_ops | memory_only_ops
@@ -112,163 +108,109 @@ class SupportedOperators:
mean_kernel_product = 64 * 64
mean_kernel_product_int8 = 16 * 16
mean_kernel_product_avgpool = 256 * 256
- # Supported consumers
- supported_pad_consumers = convolution_ops | depthwise_convolution_ops | pooling_ops
def __init__(self):
# Setup the generic constraints. Note: the order matters
self.generic_constraints = []
- self.generic_constraints.append(SupportedOperators.constraint_tens_no_dynamic)
- self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
- self.generic_constraints.append(SupportedOperators.constraint_tens_output_scalar)
- self.generic_constraints.append(SupportedOperators.constraint_tens_input_scalar)
- self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
- self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
- self.generic_constraints.append(SupportedOperators.constraint_tens_int32_ops)
- self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
- self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
- self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
- self.generic_constraints.append(SupportedOperators.constraint_tens_quant_per_axis)
- self.generic_constraints.append(SupportedOperators.constraint_faf)
- self.generic_constraints.append(SupportedOperators.constraint_faf_type)
- self.generic_constraints.append(SupportedOperators.constraint_quant_scale_inf)
+ self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_dtype)
+ self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_int32_ops)
+ self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_dimension)
+ self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_quant_per_axis)
+ self.generic_constraints.append(TFLiteSupportedOperators.constraint_faf)
+ self.generic_constraints.append(TFLiteSupportedOperators.constraint_faf_type)
# Setup specific constraints. Note: the order matters
self.specific_constraints = defaultdict(list)
# Conv-like checks:
- for op_type in SupportedOperators.convolution_like_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_range)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_dilation_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_dilation_range)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_dilated_height_range)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_dilated_product_range)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_const)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_limit)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_40bit)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_batch_size)
+ for op_type in TFLiteSupportedOperators.convolution_like_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_dilation_range)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_dilated_height_range)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_dilated_product_range)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_type)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_const)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_limit)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_batch_size)
# Depthwise Conv specific checks:
- for op_type in SupportedOperators.depthwise_convolution_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_depth_multiplier)
+ for op_type in TFLiteSupportedOperators.depthwise_convolution_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier)
# Transpose Conv specific checks:
- for op_type in SupportedOperators.transpose_convolution_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_stride)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_same)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_valid)
+ for op_type in TFLiteSupportedOperators.transpose_convolution_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_tconv_stride)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_tconv_same)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_tconv_valid)
# Pooling checks:
- for op_type in SupportedOperators.pooling_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_batch_size)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_range)
+ for op_type in TFLiteSupportedOperators.pooling_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_batch_size)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
# AVG pooling specific checks:
- for op_type in SupportedOperators.avg_pooling_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_range)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_height_range_valid_pad)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_product_range_valid_pad)
+ for op_type in TFLiteSupportedOperators.avg_pooling_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_range)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_height_range_valid_pad)
+ self.specific_constraints[op_type].append(
+ TFLiteSupportedOperators.constraint_filter_product_range_valid_pad
+ )
# MAX pooling specific checks:
- for op_type in SupportedOperators.max_pooling_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_height_range)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_product_range)
+ for op_type in TFLiteSupportedOperators.max_pooling_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_height_range)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_product_range)
# Resizing specific checks:
- for op_type in SupportedOperators.resizing_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_resize)
+ for op_type in TFLiteSupportedOperators.resizing_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize)
# Vector Product specific checks:
- for op_type in SupportedOperators.fc_vector_products:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_const)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_type)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_40bit)
-
- # Concat specific checks:
- for op_type in (Op.Concat, Op.ConcatTFLite):
- self.specific_constraints[op_type].append(SupportedOperators.constraint_axis_exists)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_axis_valid)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_dimensionality)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_valid_dimensions)
+ for op_type in TFLiteSupportedOperators.fc_vector_products:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_type)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_const)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit)
# Element-wise checks:
- for op_type in SupportedOperators.elem_wise_main_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_elemwise_batch_size)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_either_shapes)
- # Unary specific checks:
- for op_type in SupportedOperators.unary_elem_wise_main_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+ for op_type in TFLiteSupportedOperators.elem_wise_main_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_elemwise_batch_size)
# Binary Min/Max specific checks:
- for op_type in SupportedOperators.binary_elem_wise_min_max_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_quantization_parameters)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
+ for op_type in TFLiteSupportedOperators.binary_elem_wise_min_max_ops:
+ self.specific_constraints[op_type].append(
+ TFLiteSupportedOperators.constraint_matching_quantization_parameters
+ )
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_broadcast_shapes)
# Binary Add/Mul/Sub specific checks:
- for op_type in SupportedOperators.binary_elem_wise_add_mul_sub:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_inputs_types)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_signed)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_unsigned_valid)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
+ for op_type in TFLiteSupportedOperators.binary_elem_wise_add_mul_sub:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_broadcast_shapes)
# Binary Shift specific checks:
- for op_type in SupportedOperators.binary_elem_wise_shift_ops:
- self.specific_constraints[op_type].append(SupportedOperators.constraint_inputs_int32)
- self.specific_constraints[op_type].append(SupportedOperators.constraint_broadcast_shapes)
+ for op_type in TFLiteSupportedOperators.binary_elem_wise_shift_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_inputs_int32)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_broadcast_shapes)
# SHL specific checks:
- self.specific_constraints[Op.SHL].append(SupportedOperators.constraint_output_int32)
+ self.specific_constraints[Op.SHL].append(TFLiteSupportedOperators.constraint_output_int32)
# CLZ specific checks:
- self.specific_constraints[Op.CLZ].append(SupportedOperators.constraint_output_int32)
-
- # Softmax specific checks:
- self.specific_constraints[Op.Softmax].append(SupportedOperators.constraint_matching_shapes)
- self.specific_constraints[Op.Softmax].append(SupportedOperators.constraint_matching_in_out_types)
- self.specific_constraints[Op.Softmax].append(SupportedOperators.constraint_beta_value_range)
-
- # SplitV specific checks:
- self.specific_constraints[Op.SplitV].append(SupportedOperators.constraint_splitv_inferred)
+ self.specific_constraints[Op.CLZ].append(TFLiteSupportedOperators.constraint_output_int32)
# StridedSlice specific checks:
- self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_input_count)
- self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_inputs_const)
- self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_stride_values)
- self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_ellipsis_mask)
- self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_axis_masks)
- self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_slice_ranges)
-
- # LeakyRelu specific checks:
- self.specific_constraints[Op.LeakyRelu].append(SupportedOperators.constraint_alpha_valid)
-
- # FullyConnected specific checks:
- self.specific_constraints[Op.FullyConnected].append(SupportedOperators.constraint_fc_output_2d)
- self.specific_constraints[Op.FullyConnected].append(SupportedOperators.constraint_keep_dim_ifm_ofm)
+ self.specific_constraints[Op.StridedSlice].append(
+ TFLiteSupportedOperators.constraint_stridedslice_stride_values
+ )
# Pad specific checks:
- self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_input_count)
- self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_shape)
- self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_padding_dimensions)
- self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_type)
- self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant)
+ self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_pad_shape)
+ self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_padding_dimensions)
+ self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_pad_type)
- # HardSwish specific checks:
- self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
- self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types)
# Mean specific checks:
- self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_input_8bit)
- self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_input_dims)
- self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_axis)
- self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_avgpool)
- self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product)
- self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_int8)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8)
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
- if op.type not in SupportedOperators.supported_operators:
+ if op.type not in TFLiteSupportedOperators.supported_operators:
if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
print(f"Info: {ext_type} '{op.name}' is a CPU only op")
return False
@@ -284,64 +226,6 @@ class SupportedOperators:
return True
- @staticmethod
- def constraint_tens_no_dynamic(op):
- "Input(s) and Output tensors must not be dynamic"
- valid = True
- extra = []
- tensors = [tens for tens in op.inputs + op.outputs if tens]
- for tens in tensors:
- if (tens.shape == []) and (tens.values is None):
- valid = False
- extra.append(tens.name)
- extra = ", ".join(extra)
- return valid, f"Op has dynamic tensor(s): {extra}"
-
- @staticmethod
- def constraint_tens_defined_shape(op):
- "Input(s) and Output tensors must have a defined shape"
- valid = True
- extra = []
- tensors = [tens for tens in op.inputs + op.outputs if tens]
- for tens in tensors:
- if not tens.has_fully_defined_shape():
- valid = False
- extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
- return valid, ", ".join(extra)
-
- @staticmethod
- def constraint_tens_output_scalar(op):
- "Output tensors cannot be scalar"
- ofm = op.ofm
- valid = ofm.shape != []
- return valid, f"Output Tensor '{ofm.name}' is scalar"
-
- @classmethod
- @docstring_format_args([_optype_formatter(shapeless_input_ops)])
- def constraint_tens_input_scalar(cls, op):
- "Scalar Input tensors are only valid for op type: {}"
- valid = True
- extra = []
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
- valid = False
- extra.append(tens.name)
- extra = ", ".join(extra)
- return valid, f"Op has scalar input tensor(s): {extra}"
-
- @staticmethod
- def constraint_tens_shape_size(op):
- "Input(s) and Output tensors must not be greater than 4D"
- valid = True
- extra = []
- tensors = [tens for tens in op.inputs + op.outputs if tens]
- for tens in tensors:
- if len(tens.shape) > 4:
- valid = False
- extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
- return valid, ", ".join(extra)
-
@classmethod
@docstring_format_args([list_formatter(supported_op_dtypes)])
def constraint_tens_dtype(cls, op):
@@ -389,31 +273,6 @@ class SupportedOperators:
extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
return valid, ", ".join(extra)
- @staticmethod
- def constraint_tens_quant_none_check(op):
- "Input(s), Output and Weight tensors must have quantization parameters"
- valid = True
- extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- for tens in tensors:
- if tens.quantization is None:
- valid = False
- extra.append(tens.name)
- extra = ", ".join(extra)
- return valid, f"Op has tensors with missing quantization parameters: {extra}"
-
- @staticmethod
- def constraint_tens_quant_scale(op):
- "Input(s), Output and Weight tensors with quantization scales must be finite"
- valid = True
- extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- for tens in tensors:
- if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
- valid = False
- extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
- return valid, ", ".join(extra)
-
@classmethod
@docstring_format_args([_optype_formatter(per_axis_quant_ops)])
def constraint_tens_quant_per_axis(cls, op):
@@ -428,17 +287,6 @@ class SupportedOperators:
extra.append(tens.name)
return valid, "The following tensor(s) have per-axis quantization parameters: " + ", ".join(extra)
- @staticmethod
- def constraint_fc_output_2d(op):
- "The output tensor(s) must have 2D shape"
- valid = True
- extra = []
- for tens in op.outputs:
- if len(tens.shape) != 2:
- valid = False
- extra.append(f"Tensor '{tens.name}' is {len(tens.shape)}D")
- return valid, ", ".join(extra)
-
@classmethod
@docstring_format_args([_optype_formatter(supported_fused_activations)])
def constraint_faf(cls, op):
@@ -463,13 +311,6 @@ class SupportedOperators:
res = valid, f"Op has fused activation function {ext_type}, and Output tensor data type: {op.ofm.dtype}"
return res
- @staticmethod
- def constraint_stride_type(op):
- "Stride values for both width and height must be integer types"
- w, h = op.get_kernel_stride()
- valid = is_integer(w) and is_integer(h)
- return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"
-
@classmethod
@docstring_format_args(stride_range)
def constraint_stride_range(cls, op):
@@ -479,13 +320,6 @@ class SupportedOperators:
valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max)
return valid, f"Op has stride WxH as: {w}x{h}"
- @staticmethod
- def constraint_dilation_type(op):
- "Dilation factor values for both width and height must be integer types"
- w, h = op.get_kernel_dilation()
- valid = is_integer(w) and is_integer(h)
- return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"
-
@classmethod
@docstring_format_args(dilation_range)
def constraint_dilation_range(cls, op):
@@ -564,26 +398,6 @@ class SupportedOperators:
return valid, f"Tensor '{ifm.name}' has batch size: {ifm.shape[0]}"
@staticmethod
- def constraint_quant_scale_inf(op):
- "Input and Output tensors must have quantization scales that fit within float32 precision"
- if op.ofm is not None and op.ofm.is_quantized():
- ofm_scale = op.ofm.quantization.scale_f32
- if ofm_scale < np.finfo(np.float32).tiny:
- return (
- False,
- f"The quantization scale of the output tensor is {ofm_scale}, "
- + f"minimum supported is: {np.finfo(np.float32).tiny}",
- )
- if op.ifm is not None and op.ifm.is_quantized():
- ifm_scale = op.ifm.quantization.scale_f32
- if np.isinf(ifm_scale / ofm_scale):
- return (
- False,
- f"IFM scale divided by OFM scale is infinite, ifm_scale={ifm_scale} ofm_scale={ofm_scale}",
- )
- return True, "Op's quantization is ok"
-
- @staticmethod
def constraint_depth_multiplier(op):
"For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier"
depth_multiplier = op.attrs.get("depth_multiplier", 1)
@@ -639,29 +453,6 @@ class SupportedOperators:
return valid, extra
return True, "Op has padding=SAME"
- @staticmethod
- def constraint_matching_in_out_types(op):
- "IFM and OFM data types must match"
- ifm_dtype = op.ifm.dtype
- ofm_dtype = op.ofm.dtype
- valid = ifm_dtype == ofm_dtype
- return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
-
- @staticmethod
- def constraint_beta_value_range(op):
- "Beta value needs to be positive"
- beta = op.attrs.get("beta", 1.0)
- valid = beta >= 0
- return valid, f"Op has beta={beta}"
-
- @staticmethod
- def constraint_filter_type(op):
- "Kernel filter values for both width and height must be integer types"
- w = op.kernel.width
- h = op.kernel.height
- valid = is_integer(w) and is_integer(h)
- return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"
-
@classmethod
@docstring_format_args(filter_range)
def constraint_filter_range(cls, op):
@@ -697,7 +488,7 @@ class SupportedOperators:
def constraint_filter_height_range_valid_pad(op):
"VALID padding: Kernel filter height must be in the range [{}, {}]"
if op.attrs["padding"] == Padding.VALID:
- return SupportedOperators.constraint_filter_height_range(op)
+ return TFLiteSupportedOperators.constraint_filter_height_range(op)
return True, "Op has padding=SAME"
@staticmethod
@@ -705,7 +496,7 @@ class SupportedOperators:
def constraint_filter_product_range_valid_pad(op):
"VALID padding: Product of kernel filter width and height must be in the range [{}, {}]"
if op.attrs["padding"] == Padding.VALID:
- return SupportedOperators.constraint_filter_product_range(op)
+ return TFLiteSupportedOperators.constraint_filter_product_range(op)
return True, "Op has padding=SAME"
@staticmethod
@@ -738,83 +529,6 @@ class SupportedOperators:
return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and align_corners={align_corners}"
@staticmethod
- def constraint_matching_shapes(op):
- "IFM and OFM shapes must match"
- ifm_shape = op.ifm.shape
- ofm_shape = op.ofm.shape
- valid = ifm_shape == ofm_shape
- return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"
-
- @staticmethod
- def constraint_splitv_inferred(op):
- "Only one size is allowed to be inferred"
- sizes = op.inputs[1].values
- valid = np.count_nonzero(sizes == -1) <= 1
- return valid, f"Op has multiple inferred sizes (-1): {sizes}"
-
- @staticmethod
- def constraint_axis_exists(op):
- "Axis attribute must exist"
- axis = op.attrs.get("axis")
- valid = axis is not None
- return valid, f"Op has axis={axis}"
-
- @staticmethod
- def constraint_axis_valid(op):
- "Axis attribute must be in the range [0, <ofm_dimensions>)"
- dims = len(op.ofm.shape)
- axis = op.attrs["axis"]
- axis += dims if axis < 0 else 0
- valid = 0 <= axis < dims
- return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"
-
- @staticmethod
- def constraint_matching_dimensionality(op):
- "All Input dimensionalities must match OFM dimensionality"
- valid = True
- extra = []
- ofm_dim = len(op.ofm.shape)
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- dim = len(tens.shape)
- if dim != ofm_dim:
- valid = False
- extra.append(f"Tensor '{tens.name}' has dimension: {dim}")
- extra = ", ".join(extra)
- return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"
-
- @staticmethod
- def constraint_valid_dimensions(op):
- "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
- valid = True
- extra = []
- ofm_shape = op.ofm.shape
- ofm_dim = len(ofm_shape)
- axis = op.attrs["axis"]
- axis += ofm_dim if axis < 0 else 0
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
- valid = False
- extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
- extra = ", ".join(extra)
- return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
-
- @staticmethod
- def constraint_stridedslice_input_count(op):
- "Exactly 4 Input tensors are required"
- inputs = len(op.inputs)
- valid = inputs == 4
- return valid, f"Op has {inputs} inputs"
-
- @staticmethod
- def constraint_pad_input_count(op):
- "Number of input tensors must be exactly 2"
- inputs = len(op.inputs)
- valid = inputs == 2
- return valid, f"Op has {inputs} inputs"
-
- @staticmethod
def constraint_pad_shape(op):
"The padding tensor must have the shape [3,2] or [4,2]"
valid = op.inputs[1].shape in ([3, 2], [4, 2])
@@ -839,31 +553,6 @@ class SupportedOperators:
return valid, f"First dimension padding: {pad_tensor[0,:]}, last dimension padding: {pad_tensor[-1,:]}"
@staticmethod
- def constraint_pad_constant(op):
- "The padding tensor must be constant"
- pad_tensor = op.inputs[1].values
- valid = pad_tensor is not None
- return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"
-
- @staticmethod
- def constraint_stridedslice_inputs_const(op):
- "Begin, End and Stride Input tensors must be constant"
- valid = True
- extra = []
- _, begin, end, strides = op.inputs
- if begin.values is None:
- valid = False
- extra.append(f"Begin tensor '{begin.name}'")
- if end.values is None:
- valid = False
- extra.append(f"End tensor '{end.name}'")
- if strides.values is None:
- valid = False
- extra.append(f"Stride tensor '{strides.name}'")
- extra = ", ".join(extra)
- return valid, f"Op has non-constant tensors: {extra}"
-
- @staticmethod
def constraint_stridedslice_stride_values(op):
"All Strides values must be 1"
strides = op.inputs[3]
@@ -871,60 +560,6 @@ class SupportedOperators:
return valid, f"Op has strides values {strides.values}"
@staticmethod
- def constraint_ellipsis_mask(op):
- "ellipsis_mask must be 0"
- ellipsis = op.attrs["ellipsis_mask"]
- valid = ellipsis == 0
- return valid, f"Op has ellipsis mask as: {ellipsis}"
-
- @staticmethod
- def constraint_axis_masks(op):
- "new_axis_mask and shrink_axis_mask cannot both be set"
- new_axis = op.attrs["new_axis_mask"]
- shrink_axis = op.attrs["shrink_axis_mask"]
- valid = (new_axis == 0) or (shrink_axis == 0)
- return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"
-
- @staticmethod
- def constraint_slice_ranges(op):
- "Slice 'end' values must be greater than 'begin' values"
- ifm, begin, end, _ = op.inputs
- # Calculate offset begin/end
- offset_begin = get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
- offset_end = get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
- # Check "end - begin" doesn't result in any zero or negative elements
- valid = all((e - b) > 0 for b, e in zip(offset_begin, offset_end))
- return valid, f"Op has begin_values={begin.values} and end_values={end.values}"
-
- @staticmethod
- def constraint_matching_inputs_types(op):
- "Both Input data types must match"
- ifm_dtype = op.ifm.dtype
- ifm2_dtype = op.ifm2.dtype
- valid = ifm_dtype == ifm2_dtype
- return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"
-
- @staticmethod
- def constraint_matching_signed(op):
- "For IFM that are signed, OFM must also be signed"
- valid = True
- ifm_dtype = op.ifm.dtype
- ofm_dtype = op.ofm.dtype
- if ifm_dtype.type & BaseType.Signed:
- valid = bool(ofm_dtype.type & BaseType.Signed)
- return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
-
- @staticmethod
- def constraint_unsigned_valid(op):
- "For IFM that are unsigned, OFM must either be the same type or int32"
- valid = True
- ifm_dtype = op.ifm.dtype
- ofm_dtype = op.ofm.dtype
- if ifm_dtype.type & BaseType.Unsigned:
- valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
- return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
-
- @staticmethod
def constraint_inputs_int32(op):
"Both Input data types must be int32"
ifm_dtype = op.ifm.dtype
@@ -940,13 +575,6 @@ class SupportedOperators:
return valid, f"Op has ofm_dtype={ofm_dtype}"
@staticmethod
- def constraint_input_8bit(op):
- "IFM must be int8 or uint8"
- ifm_dtype = op.ifm.dtype
- valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
- return valid, f"Op has ifm_dtype={ifm_dtype}"
-
- @staticmethod
def constraint_matching_quantization_parameters(op):
"Both Input quantization parameters must match OFM quantization parameters"
valid = True
@@ -975,15 +603,6 @@ class SupportedOperators:
return valid, f"Op has invalid input tensors: {extra}"
@staticmethod
- def constraint_matching_either_shapes(op):
- "At least one Input's shape must match the OFM's shape"
- ifm_shape = op.ifm.shape
- ifm2_shape = op.ifm2.shape if op.ifm2 else None
- ofm_shape = op.ofm.shape
- valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
- return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"
-
- @staticmethod
def constraint_broadcast_shapes(op):
"Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2"
ifm_shape = op.ifm.shape
@@ -1004,38 +623,6 @@ class SupportedOperators:
return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}"
- @staticmethod
- def constraint_alpha_valid(op):
- "Alpha must not be negative"
- alpha = op.attrs["alpha"]
- valid = alpha >= 0
- return valid, f"Op has alpha={alpha}"
-
- @staticmethod
- def constraint_keep_dim_ifm_ofm(op):
- "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
- valid = True
- if op.attrs.get("keep_num_dims"):
- valid = len(op.ifm.shape) == len(op.ofm.shape)
- return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"
-
- @staticmethod
- def constraint_mean_input_dims(op):
- "Input tensor must be at least 2D"
- dims = len(op.inputs[0].shape)
- return 2 <= dims <= 4, f"Input is {dims}D"
-
- @staticmethod
- def constraint_mean_axis(op):
- "Axis indices must correspond to height and width axes"
- dims = len(op.inputs[0].shape)
- axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
- if dims == 2 or dims == 3:
- valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
- elif dims == 4:
- valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
- return valid, f"Axis is {axis}"
-
@classmethod
@docstring_format_args([mean_kernel_product_avgpool])
def constraint_mean_height_width_product_avgpool(cls, op):
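With this split, the checks run in two stages: tflite_semantic_checker() flags semantic violations right after the network is read, while TFLiteSupportedOperators.is_operator_supported() applies the hardware-capability constraints later during graph optimisation. A simplified sketch of that flow, not the actual wiring (which goes through model_reader.py, architecture_features.py and tflite_graph_optimiser.py):

    from ethosu.vela.tflite_model_semantic import tflite_semantic_checker
    from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators

    def place_operators(nng):
        # Stage 1: model semantics, run right after the network is read.
        nng = tflite_semantic_checker(nng)

        # Stage 2: NPU capability checks; only ops that passed stage 1 are tested.
        supported = TFLiteSupportedOperators()
        for sg in nng.subgraphs:
            for op in sg.get_all_ops():
                if op.run_on_npu:
                    op.run_on_npu = supported.is_operator_supported(op)
        return nng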
diff --git a/ethosu/vela/tosa_model_semantic.py b/ethosu/vela/tosa_model_semantic.py
new file mode 100644
index 00000000..5cd186c6
--- /dev/null
+++ b/ethosu/vela/tosa_model_semantic.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# The TosaSemantic class which is a collection of TOSA model semantic checks.
+from collections import defaultdict
+
+from .operation import Op
+from .tosa_mapping import optype_to_tosa_op_type
+
+
+class TosaSemantic:
+ # TODO populate this
+
+ def __init__(self):
+ # Setup the generic constraints. Note: the order matters
+ self.generic_constraints = []
+
+ # Setup specific constraints. Note: the order matters
+ self.specific_constraints = defaultdict(list)
+
+ def is_operator_semantic_valid(self, op):
+ ext_type = optype_to_tosa_op_type(op.type)
+
+ if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
+ return True
+
+ for constraint in self.generic_constraints + self.specific_constraints[op.type]:
+ valid, extra = constraint(op)
+ if not valid:
+ print(f"Warning: unsupported TOSA semantics for {ext_type} '{op.name}'.")
+ print(f" - {constraint.__doc__}")
+ if extra:
+ print(f" {extra}")
+ return False
+
+ return True
+
+
+def tosa_semantic_checker(nng):
+ semantic_checker = TosaSemantic()
+ for sg in nng.subgraphs:
+ for op in sg.get_all_ops():
+ op.run_on_npu = semantic_checker.is_operator_semantic_valid(op)
+ return nng
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index ecdc7aa7..b8a3b9f2 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -35,14 +35,18 @@ from .api import API_VERSION
from .debug_database import DebugDatabase
from .errors import InputFileError
from .errors import VelaError
+from .nn_graph import NetworkType
from .nn_graph import PassPlacement
from .nn_graph import TensorAllocator
-from .supported_operators import SupportedOperators
from .tensor import MemArea
from .tensor import Tensor
from .tflite.Model import Model
from .tflite_mapping import builtin_operator_map
from .tflite_mapping import builtin_type_name
+from .tflite_model_semantic import TFLiteSemantic
+from .tflite_supported_operators import TFLiteSupportedOperators
+from .tosa_model_semantic import TosaSemantic
+from .tosa_supported_operators import TosaSupportedOperators
from ethosu.vela.architecture_features import ArchitectureFeatures
@@ -169,50 +173,91 @@ def generate_supported_ops():
"This file complies with",
"[**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)",
"",
- "## Summary Table",
- "",
- "The table below contains TFLite operators that can be placed on the Ethos-U NPU. ",
- "If the constraints are not met, then that operator will be scheduled on the CPU instead. ",
- "For any other TFLite operator not listed, will be left untouched and scheduled on the CPU. ",
- "Please check the supported operator list for your chosen runtime for further information.",
- "",
- "| Operator | Constraints |",
- "| --- | --- |",
+ "Summary table of constraints for:",
]
- supported = SupportedOperators()
- op_constraint_links = []
- op_list = sorted(((op, builtin_type_name(op)) for op in builtin_operator_map), key=lambda x: x[1])
- for op, name in op_list:
- internal_op = builtin_operator_map[op][0]
- if internal_op in SupportedOperators.supported_operators:
- links = "[Generic](#generic-constraints)"
- if internal_op in supported.specific_constraints:
- links += f", [Specific](#{name.lower()}-constraints)"
- op_constraint_links.append((internal_op, name))
- lines.append(f"| {name} | {links} |")
- lines += [
- "",
- "## Generic Constraints",
- "",
- "This is a list of constraints that all NPU operators must satisfy in order to be scheduled on the NPU.",
- "",
- ]
- for constraint in supported.generic_constraints:
- # Markdown needs two spaces at the end of a line to render it as a separate line
- reason = constraint.__doc__.replace("\n", " \n")
- lines.append(f"- {reason}")
- for op, name in op_constraint_links:
+
+ for network_type in NetworkType:
+ lines += [
+ f"- [{network_type.name}](#{network_type.name.lower()}-summary-table)",
+ ]
+
+ for network_type in NetworkType:
+ lines += [
+ "",
+ f"## {network_type.name} Summary Table",
+ "",
+ ]
+ if network_type == NetworkType.TFLite:
+ lines += [
+ "The table below contains TFLite operators that can be placed on the Ethos-U NPU. ",
+ "If the constraints are not met, then that operator will be scheduled on the CPU instead. ",
+ "For any other TFLite operator not listed, will be left untouched and scheduled on the CPU. ",
+ "Please check the supported operator list for your chosen runtime for further information.",
+ "",
+ "| Operator | TFLite Constraints |",
+ "| --- | --- |",
+ ]
+ semantic_checker = TFLiteSemantic()
+ supported = TFLiteSupportedOperators()
+ elif network_type == NetworkType.TOSA:
+ lines += [
+ "The table below contains TOSA operators that can be placed on the Ethos-U NPU. ",
+ "Note: There is limited support for compiling a TOSA neural network (EXPERIMENTAL). ",
+ "The related constraints have not yet been populated in the list.",
+ "",
+ "| Operator | TOSA Constraints |",
+ "| --- | --- |",
+ ]
+ semantic_checker = TosaSemantic()
+ supported = TosaSupportedOperators()
+ else:
+ raise ValueError
+
+ op_constraint_links = []
+ op_list = sorted(((op, builtin_type_name(op)) for op in builtin_operator_map), key=lambda x: x[1])
+ for op, name in op_list:
+ internal_op = builtin_operator_map[op][0]
+ if internal_op in TFLiteSupportedOperators.supported_operators:
+ links = f"[Generic](#{network_type.name.lower()}-generic-constraints)"
+ if (
+ internal_op in supported.specific_constraints
+ or internal_op in semantic_checker.specific_constraints
+ ):
+ links += f", [Specific](#{network_type.name.lower()}-{name.lower()}-constraints)"
+ op_constraint_links.append((internal_op, name))
+ lines.append(f"| {name} | {links} |")
lines += [
"",
- f"## {name} Constraints",
+ f"### {network_type.name} Generic Constraints",
"",
- f"This is a list of constraints that the {name} operator must satisfy in order to be scheduled on the NPU.",
+ "This is a list of constraints that all NPU operators must satisfy in order to be scheduled on the NPU.",
"",
]
- for constraint in supported.specific_constraints[op]:
+ for constraint in semantic_checker.generic_constraints:
+ # Markdown needs two spaces at the end of a line to render it as a separate line
+ reason = constraint.__doc__.replace("\n", " \n")
+ lines.append(f"- {reason}")
+ for constraint in supported.generic_constraints:
# Markdown needs two spaces at the end of a line to render it as a separate line
reason = constraint.__doc__.replace("\n", " \n")
lines.append(f"- {reason}")
+ for op, name in op_constraint_links:
+ lines += [
+ "",
+ f"### {network_type.name} {name} Constraints",
+ "",
+ f"This is a list of constraints that the {name} operator must satisfy in order to be scheduled on the"
+ " NPU.",
+ "",
+ ]
+ for constraint in semantic_checker.specific_constraints[op]:
+ # Markdown needs two spaces at the end of a line to render it as a separate line
+ reason = constraint.__doc__.replace("\n", " \n")
+ lines.append(f"- {reason}")
+ for constraint in supported.specific_constraints[op]:
+ # Markdown needs two spaces at the end of a line to render it as a separate line
+ reason = constraint.__doc__.replace("\n", " \n")
+ lines.append(f"- {reason}")
# Note. this will generate the file in the CWD
filepath = os.path.join(os.getcwd(), "SUPPORTED_OPS.md")
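The refactored generate_supported_ops() now emits one summary table per NetworkType and, for each operator, merges the constraint docstrings from the semantic checker and the supported-operators class. A small usage sketch, assuming the ethos-u-vela package is installed; as the comment above notes, the file is written to the current working directory:

    from ethosu.vela.vela import generate_supported_ops

    if __name__ == "__main__":
        generate_supported_ops()  # writes ./SUPPORTED_OPS.md with TFLite and TOSA sections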