aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_supported_operators.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-09-20 10:47:47 +0200
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-09-20 13:27:15 +0200
commit46408a8049f6a51dda5bfa8a4c9959e037120265 (patch)
tree68595457843f3ff4193da0542afbe5de56da8d31 /ethosu/vela/tosa_supported_operators.py
parentf436ada9caea87ec2dd686a92e41a15c1dcdeb1d (diff)
downloadethos-u-vela-46408a8049f6a51dda5bfa8a4c9959e037120265.tar.gz
TOSA: Elementwise Rank > 4 and Batch > 1
Added support for elementwise operations: -Support for up to Rank == 6 -Support for Batch > 1 for Rank == 4 -For binary elementwise ops this includes handling of broadcasting in dimensions above H-dimension Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I73850bbfb288077a99bd2ceecbf989172016da24
Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r--ethosu/vela/tosa_supported_operators.py52
1 files changed, 27 insertions, 25 deletions
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 98df27e3..f5eddccc 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -40,15 +40,15 @@ class TosaSupportedOperators:
mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,))
binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,))
+ elem_wise_ops = binary_elem_wise_add_mul_sub
type_conversion_ops = set((Op.Rescale,))
relu_ops = set((Op.Clamp, Op.ReluN,))
activation_ops = relu_ops | set((Op.Table,))
pad_ops = set((Op.Pad,))
npu_post_ops = activation_ops
- supported_operators = (
- mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub | pad_ops
- )
+
+ supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
# Supported data types
# TODO will differ compared to TensorFlow Lite, currently set to the same
@@ -132,35 +132,37 @@ class TosaSupportedOperators:
return valid, ", ".join(extra)
# TODO This is for a HW limitation, that is to be resolved in SW later on
- @staticmethod
- def constraint_rank(op):
- "Tensor rank must be <= 4"
+ @classmethod
+ def constraint_rank(self, op):
+ "Tensor rank must be <= 4, if not elementwise"
valid = True
extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- if not tensors:
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- rank = len(tens.shape)
- if not rank <= 4:
- valid = False
- extra.append(f"Tensor '{tens.name}' has rank: {rank}")
+ if op.type not in self.binary_elem_wise_add_mul_sub:
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ rank = len(tens.shape)
+ if not rank <= 4:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has rank: {rank}")
return valid, ", ".join(extra)
# TODO This is for a HW limitation, that is to be resolved in SW later on
- @staticmethod
- def constraint_batch(op):
- "If Tensor rank is 4 batch of ifms/ofm must be 1"
+ @classmethod
+ def constraint_batch(self, op):
+ "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise"
valid = True
extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
- if not tensors:
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- rank = len(tens.shape)
- if rank == 4 and tens.shape[0] != 1:
- valid = False
- extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
+ if op.type not in self.binary_elem_wise_add_mul_sub:
+ tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ rank = len(tens.shape)
+ if rank == 4 and tens.shape[0] != 1:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
return valid, ", ".join(extra)
@staticmethod