aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
blob: 73a4f282610d1c425052a6fe3cf72a9556ffcb96 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
import numpy as np

from .data_type import BaseType
from .data_type import DataType
from .operation import get_slice_offsets


def warn_cpu(op, msg):
    """Print a warning that ``op`` is being placed on the CPU.

    ``msg`` is appended after the operator type, e.g. ``warn_cpu(op, "has bad input")``
    prints ``Warning: <op.type> has bad input, placing on CPU``.
    """
    print(f"Warning: {op.type} {msg}, placing on CPU")


class SupportedOperators:
    """Collection of supported operator types and their parameter checks.

    The constructor builds categorised sets of operator type names and maps each
    restricted type to a checker method. ``is_operator_supported`` is the entry
    point: an operation is accepted only if its type is listed, it passes the
    generic checks, and it passes the type-specific checker when one is registered.
    Rejected operations are reported (via ``warn_cpu`` for some checks) as being
    placed on the CPU.
    """

    def __init__(self):
        # Categorised lists of supported operators
        self.npu_pre_ops = {"QuantizedResizeBilinear", "SplitSliceRead"}
        self.convolution_ops = {"Conv2DBiasAct", "Conv2D", "QuantizedConv2D"}
        self.depthwise_convolution_ops = {"DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D"}
        self.transpose_convolution_ops = {"Conv2DBackpropInput"}
        self.max_pooling_ops = {"QuantizedMaxPool", "MaxPool", "MaxPoolAct"}
        self.avg_pooling_ops = {"QuantizedAvgPool", "AvgPool", "AvgPoolAct"}
        self.pooling_ops = {"ReduceSum"} | self.max_pooling_ops | self.avg_pooling_ops
        self.resizing_ops = {"ResizeBilinear"}
        self.fc_vector_products = {"QuantizedMatMul", "MatMul", "FullyConnectedAct"}
        self.mac_main_ops = (
            # convolutions
            self.convolution_ops
            # depth-wise convolutions
            | self.depthwise_convolution_ops
            # transpose convolutions
            | self.transpose_convolution_ops
            # pooling
            | self.pooling_ops
            # resizing/upscaling
            | self.resizing_ops
            # FC layers
            | self.fc_vector_products
            # RNN/LSTM/GRU
            | {"BlockLSTM"}
        )
        self.unary_elem_wise_main_ops = {"LeakyRelu", "Abs", "CLZ"}
        self.binary_elem_wise_min_max_ops = {"Minimum", "Maximum"}
        self.binary_elem_wise_shift_ops = {"SHL", "SHR"}
        self.binary_elem_wise_add_mul_sub = {
            "AddAct",
            "MulAct",
            "SubAct",
            "QuantizedAdd",
            "QuantizedSub",
            "QuantizedMul",
            "Mul",
            "Add",
            "Sub",
        }
        self.binary_elem_wise_main_ops = (
            self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
        )
        self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
        self.activation_ops = {
            "QuantizedRelu",
            "QuantizedRelu1",
            "QuantizedRelu6",
            "Relu",
            "Relu6",
            "ReluN1To1",
            "Sigmoid",
            "Tanh",
            "Softmax",
        }
        self.npu_post_ops = (
            # activation functions
            self.activation_ops
            # concatenation write direction
            | {"ConcatSliceWrite"}
            # bias add and batch norm
            | {"QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm"}
            # Quantization
            | {"Quantize"}
        )
        self.split_ops = {"Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack"}
        self.concat_ops = {"Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack"}
        self.memory_only_ops = {"Squeeze", "Reshape", "QuantizedReshape", "ExpandDims"} | self.concat_ops | self.split_ops
        self.shapeless_input_ops = self.binary_elem_wise_main_ops | {"Split", "SplitV"}
        self.supported_fused_activations = {"Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT"}
        self.supported_operators = (
            self.npu_pre_ops | self.mac_main_ops | self.elem_wise_main_ops | self.npu_post_ops | self.memory_only_ops
        )
        # Setup supported operator restriction checkers.
        # Registration order is kept; a later entry would override an earlier one for the same op type.
        restriction_checkers = (
            (self.convolution_ops, self.check_convolution_restrictions),
            (self.depthwise_convolution_ops, self.check_depthwise_convolution_restrictions),
            (self.transpose_convolution_ops, self.check_transpose_convolution_restrictions),
            (self.pooling_ops, self.check_pooling_restrictions),
            (self.resizing_ops, self.check_resize_restrictions),
            (self.fc_vector_products, self.check_vector_product_restrictions),
            (self.elem_wise_main_ops, self.check_element_wise_restrictions),
            (self.memory_only_ops, self.check_memory_only_restrictions),
            (self.activation_ops, self.check_activation_ops),
        )
        self.supported_operator_restrictions = {}
        for op_types, checker in restriction_checkers:
            self.supported_operator_restrictions.update(dict.fromkeys(op_types, checker))

    def is_operator_supported(self, op):
        """Return True if ``op`` passes every support check, False otherwise.

        The op type must be listed in ``supported_operators`` and pass
        ``check_generic_restrictions``; if a type-specific checker is registered
        for it, that checker's result is returned.
        """
        if op.type not in self.supported_operators:
            return False
        if not self.check_generic_restrictions(op):
            return False
        checker = self.supported_operator_restrictions.get(op.type)
        return True if checker is None else checker(op)

    def check_generic_restrictions(self, op):
        """Checks that apply to every operator type.

        Rejects ops that have:
        - tensors with undefined shapes or rank > 4;
        - scalar ([]) outputs, or scalar inputs unless the type is in ``shapeless_input_ops``;
        - non-integer data types;
        - element sizes > 2, except for a few op types;
        - any dimension > 65536;
        - an unsupported fused activation;
        - an infinite quantization scale.
        """
        # check fully defined shapes
        for t in op.inputs:
            if not t:
                # optional inputs may be absent
                continue
            if not t.has_fully_defined_shape():
                warn_cpu(op, "has input(s) of undefined shape")
                return False
            if t.shape == [] and op.type not in self.shapeless_input_ops:
                warn_cpu(
                    op, "has input(s) of shape []. Scalar input or broadcasting is not supported for this operator"
                )
                return False
            if len(t.shape) > 4:
                warn_cpu(op, "has input(s) of unsupported shape {}".format(t.shape))
                return False
        for t in op.outputs:
            if not t.has_fully_defined_shape():
                warn_cpu(op, "has output(s) of undefined shape")
                return False
            if t.shape == []:
                warn_cpu(op, "has output(s) of shape []. Scalar output is not supported for this operator")
                return False
            if len(t.shape) > 4:
                warn_cpu(op, "has output(s) of unsupported shape {}".format(t.shape))
                return False

        # check data type
        tensors = [t for t in op.get_ifm_ifm2_weights_ofm() if t is not None]
        if not tensors:
            # skip absent optional inputs, as the shape check above does
            tensors = [t for t in op.inputs if t]
        # op types whose tensors may use elements wider than 2 (hoisted out of the loop)
        wide_element_ops = (
            {"Requantize", "ReduceSum", "CLZ"} | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
        )
        for t in tensors:
            if not (t.dtype.type & BaseType.Int):
                return False
            if t.element_size() > 2 and op.type not in wide_element_ops:
                return False
            # check size
            if any(dim > 65536 for dim in t.shape):
                return False

        # check fused activations
        fused_activation = op.attrs.get("fused_activation_function")
        if fused_activation is not None and fused_activation not in self.supported_fused_activations:
            return False

        # check inf values
        for tens in op.get_ifm_ifm2_weights_ofm():
            if (
                (tens is not None)
                and (tens.quantization is not None)
                and (tens.quantization.scale_f32 is not None)
                and (np.isinf(tens.quantization.scale_f32).any())
            ):
                warn_cpu(op, "has inf valued tensor(s)")
                return False

        return True

    def check_convolution_restrictions(self, op):
        """Restrictions for convolutions; also applied by the depthwise and transpose convolution checkers.

        Requirements:
        - stride <= 3 and dilation factor <= 2;
        - weight element_size() <= 1 and a valid bias;
        - dilated kernel no larger than 64x64;
        - constant weights;
        - per output channel, the sum of |weight - zero_point| at most 127 * 65536;
        - batch size 1.
        """
        # check stride
        if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
            return False

        # check dilation
        dilation_w_factor = op.attrs.get("dilation_w_factor", 1)
        dilation_h_factor = op.attrs.get("dilation_h_factor", 1)
        if dilation_w_factor > 2 or dilation_h_factor > 2:
            return False

        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            return False

        if not self.check_bias_restrictions(bias_tensor):
            return False

        # check kernel size [HWIO]: effective extent once dilation gaps are included
        dilated_weight_w = weight_tensor.shape[1] + (weight_tensor.shape[1] - 1) * (dilation_w_factor - 1)
        dilated_weight_h = weight_tensor.shape[0] + (weight_tensor.shape[0] - 1) * (dilation_h_factor - 1)

        if dilated_weight_w > 64 or dilated_weight_h > 64:
            return False

        # check non const weights
        if weight_tensor.values is None:
            warn_cpu(op, "has non-const weights")
            return False

        # check weight sums over [HWI], i.e. one total per output channel (O)
        zero_point = weight_tensor.quantization.zero_point
        quant_weights = weight_tensor.quant_values.astype(np.int64)
        weights = quant_weights - zero_point
        totals = np.sum(np.absolute(weights), axis=(0, 1, 2))

        if np.amax(totals) > 127 * 65536:
            return False

        # check batch size
        if ifm_tensor.shape[0] != 1:
            return False

        return True

    def check_depthwise_convolution_restrictions(self, op):
        """Restrictions for depthwise convolutions.

        A depth_multiplier > 1 is only accepted when the ifm has depth 1 and the
        ofm depth equals the multiplier. The convolution restrictions also apply.
        """
        # check depth
        ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        if op.attrs["depth_multiplier"] > 1 and not (
            (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
        ):
            return False
        return self.check_convolution_restrictions(op)

    def check_transpose_convolution_restrictions(self, op):
        """Restrictions for transpose convolutions.

        Requirements:
        - stride 2 in both H and W;
        - an ofm H/W consistent with the padding mode;
        - the regular convolution restrictions.
        """
        # check stride: both directions must be exactly 2
        # (a chained `stride_h != stride_w != 2` would accept equal non-2 strides)
        stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
        if stride_h != 2 or stride_w != 2:
            return False

        # check output dimensions
        ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
        ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
        ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
        if op.attrs["padding"] == b"SAME":
            if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
                return False
        elif op.attrs["padding"] == b"VALID":
            kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
            if (ofm_h != ifm_h * stride_h + max(kernel_h - stride_h, 0)) or (
                ofm_w != ifm_w * stride_w + max(kernel_w - stride_w, 0)
            ):
                return False

        return self.check_convolution_restrictions(op)

    def check_pooling_restrictions(self, op):
        """Restrictions for max/avg pooling and ReduceSum.

        Requirements:
        - stride <= 3;
        - matching ifm/ofm dtype (ReduceSum excepted);
        - batch size 1;
        - kernel-size limits that depend on the pooling kind and padding.
        """
        # check stride
        if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
            return False

        # check data type
        ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        if ifm_tensor.dtype != ofm_tensor.dtype:
            if op.type != "ReduceSum":
                return False
            # TODO: else check ReduceSum restrictions.

        # check batch size
        if ifm_tensor.shape[0] != 1:
            return False

        if op.type in self.avg_pooling_ops:
            # check kernel size
            if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8):
                return False
            if op.attrs["padding"] == b"VALID" and (
                op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256
            ):
                return False

        if op.type in self.max_pooling_ops:
            # check kernel size (any padding)
            if op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256:
                return False
        return True

    def check_resize_restrictions(self, op):
        """Return True if a ResizeBilinear's output size is supported.

        Accepted cases are:
        - a 1x1 input (H and W);
        - identical input and output shapes;
        - an output H/W reached by repeatedly doubling the input H/W, subtracting 1
          after each doubling when align_corners is set.
        """
        # check unsupported upscaling factor
        if op.type == "ResizeBilinear":
            if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
                return True
            if op.inputs[0].shape == op.outputs[0].shape:
                return True
            upscaled_shape = np.array(op.inputs[0].shape[1:3])
            out_shape = np.array(op.outputs[0].shape[1:3])
            while (upscaled_shape < out_shape).all():
                upscaled_shape *= 2
                if op.attrs["align_corners"]:
                    upscaled_shape -= 1
                if np.array_equal(out_shape, upscaled_shape):
                    return True
        return False

    def check_vector_product_restrictions(self, op):
        """Restrictions for fully connected / matmul ops.

        Requires weight element_size() <= 1, a valid bias and constant weights.
        """
        # check data type
        _, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            return False

        if not self.check_bias_restrictions(bias_tensor):
            return False

        # check non const weights
        if weight_tensor.values is None:
            warn_cpu(op, "has non-const weights")
            return False

        return True

    def check_element_wise_restrictions(self, op):
        """Restrictions for unary and binary element-wise operations.

        Covers data type combinations, batch size, the LeakyRelu alpha sign, the ofm
        shape, and (for min/max) quantization equality. For unary ops ifm2_tensor is
        None, so every ifm2 access is guarded.
        """
        # check data type
        ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        # input and output datatype must match for these operators
        if (
            op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops
            and ifm_tensor.dtype != ofm_tensor.dtype
        ):
            return False
        if op.type in self.binary_elem_wise_add_mul_sub:
            # both inputs must have same type
            if ifm_tensor.dtype != ifm2_tensor.dtype:
                return False
            # signed input check
            if ifm_tensor.dtype.type & BaseType.Signed:
                # output must be signed
                if ofm_tensor.dtype.type & BaseType.Unsigned:
                    return False
                # and 8, 16 or 32-bit
                if ofm_tensor.element_size() not in (1, 2, 4):
                    return False
            # unsigned input check, output must be same type or int32
            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
                ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
            ):
                return False
        elif op.type in self.binary_elem_wise_shift_ops or op.type == "CLZ":
            # shift and CLZ inputs must be int32; CLZ is unary and has no ifm2
            if ifm_tensor.dtype != DataType.int32:
                return False
            if op.type in self.binary_elem_wise_shift_ops and ifm2_tensor.dtype != DataType.int32:
                return False
            if op.type in ("CLZ", "SHL") and ofm_tensor.dtype != DataType.int32:
                return False

        # check batch size
        if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
            return False
        if op.type in self.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
            if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
                return False

        # negative alpha values are not supported
        if op.type == "LeakyRelu" and op.attrs["alpha"] < 0:
            return False

        # check if ifm or (for binary ops) ifm2 has ofm shape
        ifm_matches = ifm_tensor.shape == ofm_tensor.shape
        ifm2_matches = ifm2_tensor is not None and ifm2_tensor.shape == ofm_tensor.shape
        if not (ifm_matches or ifm2_matches):
            return False

        if op.type in self.binary_elem_wise_min_max_ops and not self.check_quantization_restrictions_binary_elem_wise(
            op
        ):
            return False

        return True

    def check_memory_only_restrictions(self, op):
        """Restrictions for memory-only ops.

        StridedSlice, SplitV and op types containing "Concat" are checked. Other
        memory-only types have no extra restrictions here and are accepted.
        """
        if op.type == "StridedSlice":
            if len(op.inputs) != 4:
                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                return False
            input_tens, begin_tens, end_tens, strides_tens = op.inputs
            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
                return False
            if not (
                len(input_tens.shape)
                == len(op.outputs[0].shape)
                == len(begin_tens.values)
                == len(end_tens.values)
                == len(strides_tens.values)
            ):
                warn_cpu(op, "has input tensors with shapes that are not supported")
                return False
            # check stride size
            if any(stride != 1 for stride in strides_tens.values):
                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                return False
            # check ellipsis_mask
            if op.attrs["ellipsis_mask"] != 0:
                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                return False
            # check if both new_axis_mask and shrink_axis_mask have bit set
            if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
                return False
            # Calculate offset start/end
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
            # check "end - begin" doesn't result in any zero or negative elements
            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
                warn_cpu(
                    op,
                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
                        begin_tens.values, end_tens.values
                    ),
                )
                return False
        if op.type == "SplitV":
            # check that maximum one size is set to -1, indicating that size should be inferred
            sizes = op.inputs[1].values
            if sizes is None:
                # the sizes must be known at compile time to validate them
                warn_cpu(op, "has a non-constant sizes input tensor, which is not supported")
                return False
            num_to_be_inferred = sum(1 for size in sizes if size == -1)
            if num_to_be_inferred > 1:
                warn_cpu(op, "has more than one size to be inferred, which is illegal")
                return False
        if "Concat" in op.type:
            axis = op.attrs.get("axis", None)
            if axis is None:
                warn_cpu(op, "invalid or missing axis")
                return False
            if axis < 0:
                axis += len(op.inputs[0].shape)
            # NOTE(review): the strict lower bound rejects axis 0; confirm this is intentional.
            if not 0 < axis < len(op.inputs[0].shape):
                warn_cpu(op, "invalid axis {}".format(axis))
                return False
            ofm = op.outputs[0]
            ofm_dims = len(ofm.shape)
            for ifm in op.inputs:
                if len(ifm.shape) != ofm_dims:
                    return False
                # every dimension except the concatenation axis must match the ofm
                for i in range(ofm_dims):
                    if i != axis and ifm.shape[i] != ofm.shape[i]:
                        warn_cpu(op, "invalid ifm: {} {} mismatch in dimension {}".format(ifm.name, ifm.shape, i))
                        return False

        return True

    def check_quantization_restrictions_binary_elem_wise(self, op):
        """Return True if IFM1, IFM2 and OFM of a binary op share the same quantization scaling."""
        # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
        assert len(op.inputs) >= 2 and len(op.outputs) == 1

        if (
            op.inputs[0].quantization is None
            or not op.inputs[0].is_scaling_equal(op.inputs[1])
            or not op.inputs[0].is_scaling_equal(op.outputs[0])
        ):
            warn_cpu(op, "has input/output tensors with different quantization, which is unsupported")
            return False

        return True

    def check_activation_ops(self, op):
        """Restrictions for activation ops.

        Only Softmax is restricted: the ifm and ofm must share a dtype of uint8, int8
        or int16 and have identical shapes of rank <= 4.
        """
        if op.type == "Softmax":
            ifm_tensor = op.inputs[0]
            ofm_tensor = op.outputs[0]

            # check data type
            if ifm_tensor.dtype != ofm_tensor.dtype:
                return False

            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
                return False

            # check shape
            if len(ifm_tensor.shape) > 4 or ifm_tensor.shape != ofm_tensor.shape:
                return False

        return True

    def check_bias_restrictions(self, bias_tensor):
        """Return True if ``bias_tensor`` is absent (None) or acceptable.

        The bias must be int32 or int64. For int64, every quantized value must fit
        in a signed 40-bit range [-(1 << 39), 1 << 39).
        """
        if bias_tensor is None:
            return True

        # check data type
        if bias_tensor.dtype not in (DataType.int32, DataType.int64):
            return False

        # check if values fits in 40-bit
        if bias_tensor.dtype == DataType.int64:
            values = np.asarray(bias_tensor.quant_values)
            if np.any((values < -(1 << 39)) | (values >= (1 << 39))):
                return False

        return True