diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-09-20 10:47:47 +0200 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-09-20 13:27:15 +0200 |
commit | 46408a8049f6a51dda5bfa8a4c9959e037120265 (patch) | |
tree | 68595457843f3ff4193da0542afbe5de56da8d31 /ethosu/vela/tosa_supported_operators.py | |
parent | f436ada9caea87ec2dd686a92e41a15c1dcdeb1d (diff) | |
download | ethos-u-vela-46408a8049f6a51dda5bfa8a4c9959e037120265.tar.gz |
TOSA: Elementwise Rank > 4 and Batch > 1
Added support for elementwise operations:
-Support for up to Rank == 6
-Support for Batch > 1 for Rank == 4
-For binary elementwise ops this includes handling
of broadcasting in dimensions above H-dimension
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I73850bbfb288077a99bd2ceecbf989172016da24
Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r-- | ethosu/vela/tosa_supported_operators.py | 52 |
1 file changed, 27 insertions, 25 deletions
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py index 98df27e3..f5eddccc 100644 --- a/ethosu/vela/tosa_supported_operators.py +++ b/ethosu/vela/tosa_supported_operators.py @@ -40,15 +40,15 @@ class TosaSupportedOperators: mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,)) binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,)) + elem_wise_ops = binary_elem_wise_add_mul_sub type_conversion_ops = set((Op.Rescale,)) relu_ops = set((Op.Clamp, Op.ReluN,)) activation_ops = relu_ops | set((Op.Table,)) pad_ops = set((Op.Pad,)) npu_post_ops = activation_ops - supported_operators = ( - mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub | pad_ops - ) + + supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops # Supported data types # TODO will differ compared to TensorFlow Lite, currently set to the same @@ -132,35 +132,37 @@ class TosaSupportedOperators: return valid, ", ".join(extra) # TODO This is for a HW limitation, that is to be resolved in SW later on - @staticmethod - def constraint_rank(op): - "Tensor rank must be <= 4" + @classmethod + def constraint_rank(self, op): + "Tensor rank must be <= 4, if not elementwise" valid = True extra = [] - tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] - if not tensors: - tensors = [tens for tens in op.inputs if tens] - for tens in tensors: - rank = len(tens.shape) - if not rank <= 4: - valid = False - extra.append(f"Tensor '{tens.name}' has rank: {rank}") + if op.type not in self.binary_elem_wise_add_mul_sub: + tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] + if not tensors: + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + rank = len(tens.shape) + if not rank <= 4: + valid = False + extra.append(f"Tensor '{tens.name}' has rank: {rank}") return valid, ", ".join(extra) # TODO This is for a HW limitation, that is to be resolved in SW later on - @staticmethod - def constraint_batch(op): - "If Tensor rank is 4 batch of ifms/ofm must be 1" + @classmethod + def constraint_batch(self, op): + "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise" valid = True extra = [] - tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens] - if not tensors: - tensors = [tens for tens in op.inputs if tens] - for tens in tensors: - rank = len(tens.shape) - if rank == 4 and tens.shape[0] != 1: - valid = False - extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}") + if op.type not in self.binary_elem_wise_add_mul_sub: + tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens] + if not tensors: + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + rank = len(tens.shape) + if rank == 4 and tens.shape[0] != 1: + valid = False + extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}") return valid, ", ".join(extra) @staticmethod |