diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-09-21 14:18:44 +0200 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-09-21 14:37:10 +0200 |
commit | 3f22ec2025c8e1afe6780785fd8c62c015824a63 (patch) | |
tree | b7d3324def750afc3a0f4806b195872069e08b62 /ethosu/vela/tosa_supported_operators.py | |
parent | 46408a8049f6a51dda5bfa8a4c9959e037120265 (diff) | |
download | ethos-u-vela-3f22ec2025c8e1afe6780785fd8c62c015824a63.tar.gz |
TOSA: Decompose elem op tensors
Added decomposition of tensors exceeding
maximum size supported by NPU.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I17a99cb72947d2f1064a631ad6975ce895c258d5
Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r-- | ethosu/vela/tosa_supported_operators.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py index f5eddccc..1012a615 100644 --- a/ethosu/vela/tosa_supported_operators.py +++ b/ethosu/vela/tosa_supported_operators.py @@ -117,18 +117,19 @@ class TosaSupportedOperators: # This is for a HW limitation, that is to be resolved in SW later on @classmethod @docstring_format_args(tens_dim_range) - def constraint_tens_dimension(cls, op): - "Tensor dimensions must be in the range [{}, {}]" - tens_min, tens_max = cls.tens_dim_range + def constraint_tens_dimension(self, op): + "Tensor dimensions must be in the range [{}, {}], if not elementwise" + tens_min, tens_max = self.tens_dim_range valid = True extra = [] - tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] - if not tensors: - tensors = [tens for tens in op.inputs if tens] - for tens in tensors: - if not all(tens_min <= dim <= tens_max for dim in tens.shape): - valid = False - extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") + if op.type not in self.binary_elem_wise_add_mul_sub: + tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] + if not tensors: + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + if not all(tens_min <= dim <= tens_max for dim in tens.shape): + valid = False + extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") return valid, ", ".join(extra) # TODO This is for a HW limitation, that is to be resolved in SW later on |