aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py47
1 files changed, 46 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 14c2213..4500391 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -109,7 +109,9 @@ class TFLiteSupportedOperators:
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops | set((Op.SquaredDifference,))
pad_ops = set((Op.Pad,))
supported_int32_tensor_ops = (
- set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
+ set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax, Op.Transpose))
+ | binary_elem_wise_add_mul_sub
+ | binary_elem_wise_shift_ops
)
relu_ops = set(
@@ -163,6 +165,7 @@ class TFLiteSupportedOperators:
Op.QuantizedReshape,
Op.Squeeze,
Op.ExpandDims,
+ Op.Transpose,
)
)
| concat_ops
@@ -340,6 +343,9 @@ class TFLiteSupportedOperators:
# Slice specific checks:
self.specific_constraints[Op.Slice].append(TFLiteSupportedOperators.constraint_slice_inputs_const)
+ # Transpose specific checks:
+ self.specific_constraints[Op.Transpose].append(TFLiteSupportedOperators.constraint_transpose)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -1027,3 +1033,42 @@ class TFLiteSupportedOperators:
extra.append(f"Size tensor '{sizes.name}'")
extra = ", ".join(extra)
return valid, f"Op has non-constant tensors: {extra}"
+
+ @staticmethod
+ def constraint_transpose(op):
+ """The following shape/permutations are supported for transpose:
+ When ifm rank is 2: WxC -> CxW
+ When ifm rank is 3: HxWxC -> WxHxC, 1xWxC -> 1xCxW, Hx1xC -> Cx1xH
+ When ifm rank is 4: 1xHxWxC -> 1xWxHxC, 1x1xWxC -> 1x1xCxW, 1xHx1xC -> 1xCx1xW"""
+
+ ifm_shape = op.inputs[0].shape
+ perm = op.inputs[1]
+
+ # WxC -> CxW
+ valid = len(ifm_shape) == 2
+
+ # HxWxC -> WxHxC
+ if not valid and perm.shape == [3]:
+ valid = perm.values[0] == 1 and perm.values[1] == 0
+
+ # 1xWxC -> 1xCxW
+ if not valid and perm.shape == [3] and ifm_shape[0] == 1:
+ valid = perm.values[1] == 2 and perm.values[2] == 1
+
+ # Hx1xC -> Cx1xH
+ if not valid and perm.shape == [3] and ifm_shape[1] == 1:
+ valid = perm.values[0] == 2 and perm.values[2] == 0
+
+ # 1xHxWxC -> 1xWxHxC
+ if not valid and perm.shape == [4]:
+ valid = perm.values[0] == 0 and perm.values[1] == 2 and perm.values[2] == 1
+
+ # 1x1xWxC -> 1x1xCxW
+ if not valid and perm.shape == [4] and ifm_shape[1] == 1:
+ valid = perm.values[0] == 0 and perm.values[2] == 3 and perm.values[3] == 2
+
+ # 1xHx1xC -> 1xCx1xH
+ if not valid and perm.shape == [4] and ifm_shape[2] == 1:
+ valid = perm.values[0] == 0 and perm.values[1] == 3 and perm.values[3] == 1
+
+ return valid, f"Op has ifm_shape: {ifm_shape} and permutation is: {perm.values}"