diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 47 |
1 files changed, 46 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 14c22133..45003913 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -109,7 +109,9 @@ class TFLiteSupportedOperators: elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops | set((Op.SquaredDifference,)) pad_ops = set((Op.Pad,)) supported_int32_tensor_ops = ( - set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops + set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax, Op.Transpose)) + | binary_elem_wise_add_mul_sub + | binary_elem_wise_shift_ops ) relu_ops = set( @@ -163,6 +165,7 @@ class TFLiteSupportedOperators: Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims, + Op.Transpose, ) ) | concat_ops @@ -340,6 +343,9 @@ class TFLiteSupportedOperators: # Slice specific checks: self.specific_constraints[Op.Slice].append(TFLiteSupportedOperators.constraint_slice_inputs_const) + # Transpose specific checks: + self.specific_constraints[Op.Transpose].append(TFLiteSupportedOperators.constraint_transpose) + def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) if op.type not in TFLiteSupportedOperators.supported_operators: @@ -1027,3 +1033,42 @@ class TFLiteSupportedOperators: extra.append(f"Size tensor '{sizes.name}'") extra = ", ".join(extra) return valid, f"Op has non-constant tensors: {extra}" + + @staticmethod + def constraint_transpose(op): + """The following shape/permutations are supported for transpose: + When ifm rank is 2: WxC -> CxW + When ifm rank is 3: HxWxC -> WxHxC, 1xWxC -> 1xCxW, Hx1xC -> Cx1xH + When ifm rank is 4: 1xHxWxC -> 1xWxHxC, 1x1xWxC -> 1x1xCxW, 1xHx1xC -> 1xCx1xW""" + + ifm_shape = op.inputs[0].shape + perm = op.inputs[1] + + # WxC -> CxW + valid = len(ifm_shape) == 2 + + # HxWxC -> WxHxC + if not valid and perm.shape == [3]: + valid = perm.values[0] == 1 and perm.values[1] == 0 + + # 1xWxC -> 1xCxW + if not valid and perm.shape == [3] and ifm_shape[0] == 1: + valid = perm.values[1] == 2 and perm.values[2] == 1 + + # Hx1xC -> Cx1xH + if not valid and perm.shape == [3] and ifm_shape[1] == 1: + valid = perm.values[0] == 2 and perm.values[2] == 0 + + # 1xHxWxC -> 1xWxHxC + if not valid and perm.shape == [4]: + valid = perm.values[0] == 0 and perm.values[1] == 2 and perm.values[2] == 1 + + # 1x1xWxC -> 1x1xCxW + if not valid and perm.shape == [4] and ifm_shape[1] == 1: + valid = perm.values[0] == 0 and perm.values[2] == 3 and perm.values[3] == 2 + + # 1xHx1xC -> 1xCx1xH + if not valid and perm.shape == [4] and ifm_shape[2] == 1: + valid = perm.values[0] == 0 and perm.values[1] == 3 and perm.values[3] == 1 + + return valid, f"Op has ifm_shape: {ifm_shape} and permutation is: {perm.values}" |