diff options
author | Johan Alfvén <johan.alfven@arm.com> | 2022-09-12 17:44:25 +0200 |
---|---|---|
committer | Johan Alfvén <johan.alfven@arm.com> | 2022-09-12 17:50:59 +0200 |
commit | b393251346c7a5eef1496bcc92c722ba92c73fd9 (patch) | |
tree | c9068b98fef8d77c3f351f12e7cbccea2ba7d52f | |
parent | b1872864b20588cc84280e5449158bb2168cc58b (diff) | |
download | ethos-u-vela-b393251346c7a5eef1496bcc92c722ba92c73fd9.tar.gz |
MLBEDSW-6863: Cleanup the constraint for concat
Removed duplicate code and moved constraint to
the correct file.
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I2da3c5b88e1af351751c481217b8183b5948f0f8
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 21 | ||||
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 46 |
2 files changed, 21 insertions, 46 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index abda886c..7a0e234d 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -126,6 +126,7 @@ class TFLiteSemantic: self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_valid) self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_dimensionality) self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions) + self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions_axis) # Element-wise checks: for op_type in TFLiteSemantic.elem_wise_main_ops: @@ -447,6 +448,26 @@ class TFLiteSemantic: return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}" @staticmethod + def constraint_valid_dimensions_axis(op): + """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute""" + valid = True + extra = [] + ofm_shape = op.ofm.shape + ofm_dim = len(ofm_shape) + axis = op.attrs["axis"] + axis += ofm_dim if axis < 0 else 0 + + sum_ifm_axis = 0 + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + sum_ifm_axis += tens.shape[axis] + extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") + + valid = sum_ifm_axis == ofm_shape[axis] + extra = ", ".join(extra) + return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}" + + @staticmethod def constraint_stridedslice_input_count(op): "Exactly 4 Input tensors are required" inputs = len(op.inputs) diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 24cc26e0..be86e9a3 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -306,13 +306,6 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_before_mean) - # Concat specific checks: - for op_type in (Op.Concat, Op.ConcatTFLite): - self.specific_constraints[op_type].append( - TFLiteSupportedOperators.constraint_concat_valid_dimensions_non_axis - ) - self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_concat_valid_dimensions_axis) - def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) if op.type not in TFLiteSupportedOperators.supported_operators: @@ -874,42 +867,3 @@ class TFLiteSupportedOperators: if next_op is not None and next_op.type == Op.Mean: return False, "" return True, "" - - @staticmethod - def constraint_concat_valid_dimensions_non_axis(op): - """All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute""" - valid = True - extra = [] - ofm_shape = op.ofm.shape - ofm_dim = len(ofm_shape) - axis = op.attrs["axis"] - axis += ofm_dim if axis < 0 else 0 - - tensors = [tens for tens in op.inputs if tens] - for tens in tensors: - if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis): - valid = False - extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") - - extra = ", ".join(extra) - return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}" - - @staticmethod - def constraint_concat_valid_dimensions_axis(op): - """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute""" - valid = True - extra = [] - ofm_shape = op.ofm.shape - ofm_dim = len(ofm_shape) - axis = op.attrs["axis"] - axis += ofm_dim if axis < 0 else 0 - - sum_ifm_axis = 0 - tensors = [tens for tens in op.inputs if tens] - for tens in tensors: - sum_ifm_axis += tens.shape[axis] - extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") - - valid = sum_ifm_axis == ofm_shape[axis] - extra = ", ".join(extra) - return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}" |