Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
 ethosu/vela/tosa_supported_operators.py | 28 +++++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 1012a615..d71e5750 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -46,6 +46,10 @@ class TosaSupportedOperators:
     activation_ops = relu_ops | set((Op.Table,))
     pad_ops = set((Op.Pad,))
 
+    rank_unlimited_ops = set((Op.Concat,))
+    rank6_limited_ops = elem_wise_ops
+    batch_enabled_ops = elem_wise_ops | set((Op.Concat,))
+    large_tens_dims_enabled_ops = elem_wise_ops | set((Op.Concat,))
     npu_post_ops = activation_ops
 
     supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
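
A minimal sketch of how these new category sets compose (the Op members and the elem_wise_ops literal below are illustrative stand-ins, not the real vela enum):

# Illustrative stand-ins for the Op enum and elem_wise_ops set (assumed).
from enum import Enum, auto

class Op(Enum):
    Add = auto()
    Sub = auto()
    Mul = auto()
    Concat = auto()
    Pad = auto()

elem_wise_ops = {Op.Add, Op.Sub, Op.Mul}

# Same composition as the class attributes added above: Concat gets
# unlimited rank, elementwise ops are capped at rank 6, and both groups
# are allowed batch > 1 and large tensor dimensions.
rank_unlimited_ops = {Op.Concat}
rank6_limited_ops = elem_wise_ops
batch_enabled_ops = elem_wise_ops | {Op.Concat}
large_tens_dims_enabled_ops = elem_wise_ops | {Op.Concat}

assert Op.Concat in rank_unlimited_ops
assert Op.Pad not in batch_enabled_ops  # Pad still takes the strict path
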
@@ -60,8 +64,10 @@
         self.generic_constraints = []
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_batch)  # TODO as not supported yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported for all ops yet
+        self.generic_constraints.append(
+            TosaSupportedOperators.constraint_batch
+        )  # TODO as not supported for all ops yet
 
         # Setup specific constraints. Note: the order matters
         self.specific_constraints = defaultdict(list)
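
For context, the registry built here is presumably consumed by looping over generic_constraints and rejecting an op on the first failed check, in the style of vela's TFLite checker; this consumer function is an assumption, not part of the patch:

def is_operator_supported(constraints, op):
    # Hypothetical consumer of the registry built above: each entry is a
    # classmethod returning (valid, extra); the first failure rejects op.
    for constraint in constraints:
        valid, extra = constraint(op)
        if not valid:
            print(f"Info: {op.type} rejected: {constraint.__doc__} ({extra})")
            return False
    return True
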
@@ -118,11 +124,11 @@
     @classmethod
     @docstring_format_args(tens_dim_range)
     def constraint_tens_dimension(self, op):
-        "Tensor dimensions must be in the range [{}, {}], if not elementwise"
+        "Tensor dimensions must be in the range [{}, {}]"
         tens_min, tens_max = self.tens_dim_range
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.large_tens_dims_enabled_ops:
             tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
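
The dimension check is now skipped for everything in large_tens_dims_enabled_ops rather than only the binary add/mul/sub group. A standalone sketch of the per-tensor range test it applies; the (1, 65535) bounds are an assumption standing in for tens_dim_range:

def dims_in_range(shape, tens_min=1, tens_max=65535):
    # Hypothetical standalone version of the per-tensor test: every
    # dimension must fall inside [tens_min, tens_max]. The bounds here
    # stand in for self.tens_dim_range.
    return all(tens_min <= dim <= tens_max for dim in shape)

print(dims_in_range([1, 8, 8, 3]))      # True
print(dims_in_range([1, 70000, 8, 3]))  # False for these assumed bounds
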
@@ -135,16 +141,20 @@
     # TODO This is for a HW limitation, that is to be resolved in SW later on
     @classmethod
     def constraint_rank(self, op):
-        "Tensor rank must be <= 4, if not elementwise"
+        "Tensor rank must be <= 6 or <= 4 depending on operator"
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.rank_unlimited_ops:
+            if op.type in self.rank6_limited_ops:
+                rank_limit = 6
+            else:
+                rank_limit = 4
             tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
             for tens in tensors:
                 rank = len(tens.shape)
-                if not rank <= 4:
+                if not rank <= rank_limit:
                     valid = False
                     extra.append(f"Tensor '{tens.name}' has rank: {rank}")
         return valid, ", ".join(extra)
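
The branching added to constraint_rank reduces to a small lookup. This distilled helper, reusing the stand-in Op and sets from the first sketch, is illustrative only:

def rank_limit_for(op_type):
    # Hypothetical distillation of the new branch: Concat is uncapped,
    # elementwise ops get rank <= 6, everything else stays at rank <= 4.
    if op_type in rank_unlimited_ops:
        return None  # no limit enforced
    return 6 if op_type in rank6_limited_ops else 4

assert rank_limit_for(Op.Concat) is None
assert rank_limit_for(Op.Add) == 6
assert rank_limit_for(Op.Pad) == 4
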
@@ -152,10 +162,10 @@
     # TODO This is for a HW limitation, that is to be resolved in SW later on
     @classmethod
     def constraint_batch(self, op):
-        "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise"
+        "If Tensor rank is 4 batch of ifms/ofm must be 1"
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.batch_enabled_ops:
             tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
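
The hunk's context ends before the loop body of constraint_batch, so the test it guards is not shown. Based on the docstring, it is presumably along these lines; this helper is a guess, not code from the patch:

def batch_ok(shape):
    # Hypothetical reading of the docstring: a rank-4 tensor must have
    # batch (the leading N dimension) equal to 1; lower ranks pass.
    return len(shape) != 4 or shape[0] == 1

assert batch_ok([1, 8, 8, 3])      # rank 4, batch 1
assert not batch_ok([2, 8, 8, 3])  # rank 4, batch 2 -> rejected
assert batch_ok([8, 8, 3])         # rank 3, batch not checked
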