diff options
author | Johan Alfvén <johan.alfven@arm.com> | 2022-08-16 13:04:17 +0200 |
---|---|---|
committer | Johan Alfvén <johan.alfven@arm.com> | 2022-08-17 09:08:56 +0200 |
commit | 8e1352a00dcc0198ae2cd0d8380ef560bd3a847c (patch) | |
tree | c9e5b5a4791bb2c19e2cc068ef4b01021cbb42c7 | |
parent | 21b064c0ddc8127178aa733e1976001f9cd2a1a7 (diff) | |
download | ethos-u-vela-8e1352a00dcc0198ae2cd0d8380ef560bd3a847c.tar.gz |
MLBEDSW-6830: MLCE: Fix assert on concat op
- The compiler will assert when compiling a faulty concat op.
In the reported use case, there were 3 inputs with shape 1x1x2
but the output shape was 1x1x2 (expected to be 1x1x6)
- The solution is to add constraints to the concat operator.
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I94a505c51a9fd54d1aa92531a0415031db52378a
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 90d93d0f..5d25e37b 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -297,6 +297,13 @@ class TFLiteSupportedOperators: # Reshape specific checks: self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) + # Concat specific checks: + for op_type in (Op.Concat, Op.ConcatTFLite): + self.specific_constraints[op_type].append( + TFLiteSupportedOperators.constraint_concat_valid_dimensions_non_axis + ) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_concat_valid_dimensions_axis) + def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) if op.type not in TFLiteSupportedOperators.supported_operators: @@ -850,3 +857,42 @@ class TFLiteSupportedOperators: extra = ", ".join(extra) return valid, f"Op has non-const input(s): {extra}" + + @staticmethod + def constraint_concat_valid_dimensions_non_axis(op): + """All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute""" + valid = True + extra = [] + ofm_shape = op.ofm.shape + ofm_dim = len(ofm_shape) + axis = op.attrs["axis"] + axis += ofm_dim if axis < 0 else 0 + + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis): + valid = False + extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") + + extra = ", ".join(extra) + return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}" + + @staticmethod + def constraint_concat_valid_dimensions_axis(op): + """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute""" + valid = True + extra = [] + ofm_shape = op.ofm.shape + ofm_dim = len(ofm_shape) + axis = op.attrs["axis"] + axis += ofm_dim if axis < 0 else 0 + + sum_ifm_axis = 0 + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + sum_ifm_axis += tens.shape[axis] + extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") + + valid = sum_ifm_axis == ofm_shape[axis] + extra = ", ".join(extra) + return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}" |