Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r--  ethosu/vela/tosa_supported_operators.py  52
1 file changed, 27 insertions(+), 25 deletions(-)
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 98df27e3..f5eddccc 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -40,15 +40,15 @@ class TosaSupportedOperators:
mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,))
binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,))
+ elem_wise_ops = binary_elem_wise_add_mul_sub
type_conversion_ops = set((Op.Rescale,))
relu_ops = set((Op.Clamp, Op.ReluN,))
activation_ops = relu_ops | set((Op.Table,))
pad_ops = set((Op.Pad,))
npu_post_ops = activation_ops
- supported_operators = (
- mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub | pad_ops
- )
+
+ supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
# Supported data types
# TODO will differ compared to TensorFlow Lite, currently set to the same
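For context, a minimal sketch (illustrative names, not Vela code) of the pattern the first hunk relies on: the operator sets are plain class attributes composed with set unions, so introducing elem_wise_ops as a name for binary_elem_wise_add_mul_sub lets later changes widen the elementwise set without touching every expression that consumes it.

from enum import Enum, auto

class Op(Enum):
    Add = auto()
    Mul = auto()
    Sub = auto()
    Conv2D = auto()
    Rescale = auto()

class TosaSupportedOperators:
    binary_elem_wise_add_mul_sub = {Op.Add, Op.Mul, Op.Sub}
    # Alias that future patches can widen (e.g. with unary elementwise ops)
    # without editing every consumer of the set.
    elem_wise_ops = binary_elem_wise_add_mul_sub
    mac_main_ops = {Op.Conv2D}
    type_conversion_ops = {Op.Rescale}
    supported_operators = mac_main_ops | type_conversion_ops | elem_wise_ops

assert Op.Add in TosaSupportedOperators.supported_operators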
@@ -132,35 +132,37 @@ class TosaSupportedOperators:
return valid, ", ".join(extra)
# TODO This is for a HW limitation, that is to be resolved in SW later on
- @staticmethod
- def constraint_rank(op):
- "Tensor rank must be <= 4"
+ @classmethod
+ def constraint_rank(cls, op):
+ "Tensor rank must be <= 4, unless the operator is elementwise"
valid = True
extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- if not tensors:
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- rank = len(tens.shape)
- if not rank <= 4:
- valid = False
- extra.append(f"Tensor '{tens.name}' has rank: {rank}")
+ if op.type not in cls.binary_elem_wise_add_mul_sub:
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ rank = len(tens.shape)
+ if rank > 4:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has rank: {rank}")
return valid, ", ".join(extra)
# TODO This is for a HW limitation, that is to be resolved in SW later on
- @staticmethod
- def constraint_batch(op):
- "If Tensor rank is 4 batch of ifms/ofm must be 1"
+ @classmethod
+ def constraint_batch(cls, op):
+ "If tensor rank is 4, batch of ifms/ofm must be 1, unless the operator is elementwise"
valid = True
extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
- if not tensors:
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- rank = len(tens.shape)
- if rank == 4 and tens.shape[0] != 1:
- valid = False
- extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
+ if op.type not in cls.binary_elem_wise_add_mul_sub:
+ tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ rank = len(tens.shape)
+ if rank == 4 and tens.shape[0] != 1:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
return valid, ", ".join(extra)
@staticmethod
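The hunk above switches both constraints from @staticmethod to @classmethod so the checks can consult the class-level elementwise set and skip the HW rank/batch limits for elementwise operators. A minimal sketch of that pattern, with hypothetical stand-ins for op types and tensor shapes:

class RankChecker:
    # Stand-in for binary_elem_wise_add_mul_sub in the real class.
    elementwise = {"Add", "Mul", "Sub"}

    @classmethod
    def constraint_rank(cls, op_type, shapes):
        "Tensor rank must be <= 4, unless the operator is elementwise"
        if op_type in cls.elementwise:
            return True, ""  # elementwise ops are exempt from the rank limit
        extra = [f"rank: {len(s)}" for s in shapes if len(s) > 4]
        return not extra, ", ".join(extra)

print(RankChecker.constraint_rank("Conv2D", [(1, 8, 8, 3)]))     # (True, '')
print(RankChecker.constraint_rank("Conv2D", [(1, 2, 8, 8, 3)]))  # (False, 'rank: 5')

A @classmethod receives the class itself as its first argument (conventionally named cls), which is what gives the check access to the shared operator set; a @staticmethod receives neither the class nor an instance.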