aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_supported_operators.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-08-23 15:33:59 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-09-03 12:19:48 +0000
commitdf99510f04aef99d1b8e9be9bfcde8fc1738b65f (patch)
tree00668b0e74f95da5cc51a41b9340d8c88fbc7ffe /ethosu/vela/tosa_supported_operators.py
parentcce872bc3de3ed5f9bf1aa1a8cf9ce41cf2b2520 (diff)
downloadethos-u-vela-df99510f04aef99d1b8e9be9bfcde8fc1738b65f.tar.gz
TOSA: Added Depthwise support
This is mainly to add support for depthwise conv2d with dephmultiplier = 1. (But there are no testcases suited, all I have sourced has depth_multiplier set to 2, which is not supported.) -Added support for depthwise conv2d. -Added support for removing Transpose of constant data -Added support for removing reshape Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I143e6246becfa78fd9f7510af0bf0d6b3fbbf2c7
Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r--ethosu/vela/tosa_supported_operators.py41
1 files changed, 37 insertions, 4 deletions
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 3b0e6b39..90d54687 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -28,19 +28,24 @@ class TosaSupportedOperators:
# TODO currently sparsely populated
# Categorised lists of supported operators
convolution_ops = set((Op.Conv2DBias,))
- convolution_like_ops = convolution_ops
+ depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
+ convolution_like_ops = convolution_ops | depthwise_convolution_ops
+
+ # TODO depending on what will be committed
max_pooling_ops = Op.op_set(Op.is_maxpool_op)
avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
- pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
+ pooling_ops = max_pooling_ops | avg_pooling_ops
+ fc_vector_products = set((Op.FullyConnected,))
- mac_main_ops = convolution_like_ops | pooling_ops
+ mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
+ memory_only_ops = set((Op.Reshape, Op.Transpose,))
type_conversion_ops = set((Op.Rescale,))
relu_ops = set((Op.Clamp, Op.ReluN,))
activation_ops = relu_ops
npu_post_ops = activation_ops
- supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops
+ supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops
# Supported data types
# TODO will differ compared to TensorFlow Lite, currently set to the same
@@ -54,6 +59,12 @@ class TosaSupportedOperators:
# Setup specific constraints. Note: the order matters
self.specific_constraints = defaultdict(list)
+ self.specific_constraints[Op.Transpose].append(TosaSupportedOperators.constraint_ifm_producer)
+
+ # Depthwise Conv specific checks:
+ for op_type in TosaSupportedOperators.depthwise_convolution_ops:
+ self.specific_constraints[op_type].append(TosaSupportedOperators.constraint_depth_multiplier)
+
def is_operator_supported(self, op):
ext_type = optype_to_tosa_op_type(op.type)
if op.type not in TosaSupportedOperators.supported_operators:
@@ -87,3 +98,25 @@ class TosaSupportedOperators:
valid = False
extra.append(f"Tensor '{tens.name}' has data type: {tens.dtype}")
return valid, ", ".join(extra)
+
+ @staticmethod
+ def constraint_ifm_producer(cls, op):
+ "Input must be constant data"
+ valid = op.ifm.ops and op.ifm.ops[0].type == Op.Const
+ return valid, "Op has ifm with non-constant data"
+
+ # TODO duplicates tflite_supported operators, but support for depth multiplier should be added at a later stage
+ @staticmethod
+ def constraint_depth_multiplier(op):
+ "For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier"
+ depth_multiplier = op.attrs.get("depth_multiplier", 1)
+ if depth_multiplier > 1:
+ ifm_channels = op.ifm.shape[3]
+ ofm_channels = op.ofm.shape[3]
+ valid = (ifm_channels == 1) and (ofm_channels == depth_multiplier)
+ extra = (
+ f"Op has ifm_channels={ifm_channels}, ofm_channels={ofm_channels}"
+ f" and depth_multiplier={depth_multiplier}"
+ )
+ return valid, extra
+ return True, "Op has depth_multiplier=1"