aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_supported_operators.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-09-21 14:18:44 +0200
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-09-21 14:37:10 +0200
commit3f22ec2025c8e1afe6780785fd8c62c015824a63 (patch)
treeb7d3324def750afc3a0f4806b195872069e08b62 /ethosu/vela/tosa_supported_operators.py
parent46408a8049f6a51dda5bfa8a4c9959e037120265 (diff)
downloadethos-u-vela-3f22ec2025c8e1afe6780785fd8c62c015824a63.tar.gz
TOSA: Decompose elem op tensors
Added decomposition of tensors exceeding maximum size supported by NPU. Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I17a99cb72947d2f1064a631ad6975ce895c258d5
Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r--ethosu/vela/tosa_supported_operators.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index f5eddccc..1012a615 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -117,18 +117,19 @@ class TosaSupportedOperators:
# This is for a HW limitation, that is to be resolved in SW later on
@classmethod
@docstring_format_args(tens_dim_range)
- def constraint_tens_dimension(cls, op):
- "Tensor dimensions must be in the range [{}, {}]"
- tens_min, tens_max = cls.tens_dim_range
+ def constraint_tens_dimension(self, op):
+ "Tensor dimensions must be in the range [{}, {}], if not elementwise"
+ tens_min, tens_max = self.tens_dim_range
valid = True
extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- if not tensors:
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- if not all(tens_min <= dim <= tens_max for dim in tens.shape):
- valid = False
- extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+ if op.type not in self.binary_elem_wise_add_mul_sub:
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ if not all(tens_min <= dim <= tens_max for dim in tens.shape):
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
return valid, ", ".join(extra)
# TODO This is for a HW limitation, that is to be resolved in SW later on