aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-10-28 16:04:46 +0200
committerJohan Alfven <johan.alfven@arm.com>2023-11-09 11:59:58 +0100
commita8fda88bced0d11441467b6798885101d41ac8e9 (patch)
tree807de7fa4eee48720255fbed4a605218f8612f6a /ethosu/vela/tflite_supported_operators.py
parent4bf0cdf58416edc030ae7507ace95224785e4aa8 (diff)
downloadethos-u-vela-a8fda88bced0d11441467b6798885101d41ac8e9.tar.gz
MLBEDSW-8290: MLCE: Add TRANSPOSE support3.10.0.rc1
- Added graph optimiser function to convert TRANSPOSE op into an AvgPool op with swapped stride for height and width - Added TRANSPOSE supported op check - Added unit tests for TRANSPOSE supported op check - Updated SUPPORTED_OPS.md - Fixed problem in pass packing when optimizing the pass list. Old problem, but now seen when moving TRANSPOSE from cpu. Change-Id: I0a0ef420b0fb8241090c2e2434622881105cde15 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py47
1 files changed, 46 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 14c22133..45003913 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -109,7 +109,9 @@ class TFLiteSupportedOperators:
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops | set((Op.SquaredDifference,))
pad_ops = set((Op.Pad,))
supported_int32_tensor_ops = (
- set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
+ set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax, Op.Transpose))
+ | binary_elem_wise_add_mul_sub
+ | binary_elem_wise_shift_ops
)
relu_ops = set(
@@ -163,6 +165,7 @@ class TFLiteSupportedOperators:
Op.QuantizedReshape,
Op.Squeeze,
Op.ExpandDims,
+ Op.Transpose,
)
)
| concat_ops
@@ -340,6 +343,9 @@ class TFLiteSupportedOperators:
# Slice specific checks:
self.specific_constraints[Op.Slice].append(TFLiteSupportedOperators.constraint_slice_inputs_const)
+ # Transpose specific checks:
+ self.specific_constraints[Op.Transpose].append(TFLiteSupportedOperators.constraint_transpose)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -1027,3 +1033,42 @@ class TFLiteSupportedOperators:
extra.append(f"Size tensor '{sizes.name}'")
extra = ", ".join(extra)
return valid, f"Op has non-constant tensors: {extra}"
+
+ @staticmethod
+ def constraint_transpose(op):
+ """The following shape/permutations are supported for transpose:
+ When ifm rank is 2: WxC -> CxW
+ When ifm rank is 3: HxWxC -> WxHxC, 1xWxC -> 1xCxW, Hx1xC -> Cx1xH
+ When ifm rank is 4: 1xHxWxC -> 1xWxHxC, 1x1xWxC -> 1x1xCxW, 1xHx1xC -> 1xCx1xW"""
+
+ ifm_shape = op.inputs[0].shape
+ perm = op.inputs[1]
+
+ # WxC -> CxW
+ valid = len(ifm_shape) == 2
+
+ # HxWxC -> WxHxC
+ if not valid and perm.shape == [3]:
+ valid = perm.values[0] == 1 and perm.values[1] == 0
+
+ # 1xWxC -> 1xCxW
+ if not valid and perm.shape == [3] and ifm_shape[0] == 1:
+ valid = perm.values[1] == 2 and perm.values[2] == 1
+
+ # Hx1xC -> Cx1xH
+ if not valid and perm.shape == [3] and ifm_shape[1] == 1:
+ valid = perm.values[0] == 2 and perm.values[2] == 0
+
+ # 1xHxWxC -> 1xWxHxC
+ if not valid and perm.shape == [4]:
+ valid = perm.values[0] == 0 and perm.values[1] == 2 and perm.values[2] == 1
+
+ # 1x1xWxC -> 1x1xCxW
+ if not valid and perm.shape == [4] and ifm_shape[1] == 1:
+ valid = perm.values[0] == 0 and perm.values[2] == 3 and perm.values[3] == 2
+
+ # 1xHx1xC -> 1xCx1xH
+ if not valid and perm.shape == [4] and ifm_shape[2] == 1:
+ valid = perm.values[0] == 0 and perm.values[1] == 3 and perm.values[3] == 1
+
+ return valid, f"Op has ifm_shape: {ifm_shape} and permutation is: {perm.values}"