diff options
Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r-- | ethosu/vela/tosa_supported_operators.py | 52 |
1 files changed, 27 insertions, 25 deletions
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py index 98df27e3..f5eddccc 100644 --- a/ethosu/vela/tosa_supported_operators.py +++ b/ethosu/vela/tosa_supported_operators.py @@ -40,15 +40,15 @@ class TosaSupportedOperators: mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,)) binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,)) + elem_wise_ops = binary_elem_wise_add_mul_sub type_conversion_ops = set((Op.Rescale,)) relu_ops = set((Op.Clamp, Op.ReluN,)) activation_ops = relu_ops | set((Op.Table,)) pad_ops = set((Op.Pad,)) npu_post_ops = activation_ops - supported_operators = ( - mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub | pad_ops - ) + + supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops # Supported data types # TODO will differ compared to TensorFlow Lite, currently set to the same @@ -132,35 +132,37 @@ class TosaSupportedOperators: return valid, ", ".join(extra) # TODO This is for a HW limitation, that is to be resolved in SW later on - @staticmethod - def constraint_rank(op): - "Tensor rank must be <= 4" + @classmethod + def constraint_rank(self, op): + "Tensor rank must be <= 4, if not elementwise" valid = True extra = [] - tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] - if not tensors: - tensors = [tens for tens in op.inputs if tens] - for tens in tensors: - rank = len(tens.shape) - if not rank <= 4: - valid = False - extra.append(f"Tensor '{tens.name}' has rank: {rank}") + if op.type not in self.binary_elem_wise_add_mul_sub: + tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] + if not tensors: + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + rank = len(tens.shape) + if not rank <= 4: + valid = False + extra.append(f"Tensor '{tens.name}' has rank: {rank}") return valid, ", ".join(extra) # TODO This is for a HW limitation, that is to be resolved in SW later on - @staticmethod - def constraint_batch(op): - "If Tensor rank is 4 batch of ifms/ofm must be 1" + @classmethod + def constraint_batch(self, op): + "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise" valid = True extra = [] - tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens] - if not tensors: - tensors = [tens for tens in op.inputs if tens] - for tens in tensors: - rank = len(tens.shape) - if rank == 4 and tens.shape[0] != 1: - valid = False - extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}") + if op.type not in self.binary_elem_wise_add_mul_sub: + tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens] + if not tensors: + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + rank = len(tens.shape) + if rank == 4 and tens.shape[0] != 1: + valid = False + extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}") return valid, ", ".join(extra) @staticmethod |