aboutsummaryrefslogtreecommitdiff
path: root/ethosu
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu')
-rw-r--r--ethosu/vela/tflite_supported_operators.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 90d93d0f..5d25e37b 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -297,6 +297,13 @@ class TFLiteSupportedOperators:
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
+ # Concat specific checks:
+ for op_type in (Op.Concat, Op.ConcatTFLite):
+ self.specific_constraints[op_type].append(
+ TFLiteSupportedOperators.constraint_concat_valid_dimensions_non_axis
+ )
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_concat_valid_dimensions_axis)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -850,3 +857,42 @@ class TFLiteSupportedOperators:
extra = ", ".join(extra)
return valid, f"Op has non-const input(s): {extra}"
+
+ @staticmethod
+ def constraint_concat_valid_dimensions_non_axis(op):
+ """All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"""
+ valid = True
+ extra = []
+ ofm_shape = op.ofm.shape
+ ofm_dim = len(ofm_shape)
+ axis = op.attrs["axis"]
+ axis += ofm_dim if axis < 0 else 0
+
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+
+ extra = ", ".join(extra)
+ return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
+
+ @staticmethod
+ def constraint_concat_valid_dimensions_axis(op):
+ """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute"""
+ valid = True
+ extra = []
+ ofm_shape = op.ofm.shape
+ ofm_dim = len(ofm_shape)
+ axis = op.attrs["axis"]
+ axis += ofm_dim if axis < 0 else 0
+
+ sum_ifm_axis = 0
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ sum_ifm_axis += tens.shape[axis]
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+
+ valid = sum_ifm_axis == ofm_shape[axis]
+ extra = ", ".join(extra)
+ return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"