diff options
Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r-- | ethosu/vela/tosa_supported_operators.py | 41 |
1 files changed, 37 insertions, 4 deletions
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py index 3b0e6b39..90d54687 100644 --- a/ethosu/vela/tosa_supported_operators.py +++ b/ethosu/vela/tosa_supported_operators.py @@ -28,19 +28,24 @@ class TosaSupportedOperators: # TODO currently sparsely populated # Categorised lists of supported operators convolution_ops = set((Op.Conv2DBias,)) - convolution_like_ops = convolution_ops + depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,)) + convolution_like_ops = convolution_ops | depthwise_convolution_ops + + # TODO depending on what will be committed max_pooling_ops = Op.op_set(Op.is_maxpool_op) avg_pooling_ops = Op.op_set(Op.is_avgpool_op) - pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops + pooling_ops = max_pooling_ops | avg_pooling_ops + fc_vector_products = set((Op.FullyConnected,)) - mac_main_ops = convolution_like_ops | pooling_ops + mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products + memory_only_ops = set((Op.Reshape, Op.Transpose,)) type_conversion_ops = set((Op.Rescale,)) relu_ops = set((Op.Clamp, Op.ReluN,)) activation_ops = relu_ops npu_post_ops = activation_ops - supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops + supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops # Supported data types # TODO will differ compared to TensorFlow Lite, currently set to the same @@ -54,6 +59,12 @@ class TosaSupportedOperators: # Setup specific constraints. Note: the order matters self.specific_constraints = defaultdict(list) + self.specific_constraints[Op.Transpose].append(TosaSupportedOperators.constraint_ifm_producer) + + # Depthwise Conv specific checks: + for op_type in TosaSupportedOperators.depthwise_convolution_ops: + self.specific_constraints[op_type].append(TosaSupportedOperators.constraint_depth_multiplier) + def is_operator_supported(self, op): ext_type = optype_to_tosa_op_type(op.type) if op.type not in TosaSupportedOperators.supported_operators: @@ -87,3 +98,25 @@ class TosaSupportedOperators: valid = False extra.append(f"Tensor '{tens.name}' has data type: {tens.dtype}") return valid, ", ".join(extra) + + @staticmethod + def constraint_ifm_producer(cls, op): + "Input must be constant data" + valid = op.ifm.ops and op.ifm.ops[0].type == Op.Const + return valid, "Op has ifm with non-constant data" + + # TODO duplicates tflite_supported operators, but support for depth multiplier should be added at a later stage + @staticmethod + def constraint_depth_multiplier(op): + "For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier" + depth_multiplier = op.attrs.get("depth_multiplier", 1) + if depth_multiplier > 1: + ifm_channels = op.ifm.shape[3] + ofm_channels = op.ofm.shape[3] + valid = (ifm_channels == 1) and (ofm_channels == depth_multiplier) + extra = ( + f"Op has ifm_channels={ifm_channels}, ofm_channels={ofm_channels}" + f" and depth_multiplier={depth_multiplier}" + ) + return valid, extra + return True, "Op has depth_multiplier=1" |