diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index abda886c..7a0e234d 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -126,6 +126,7 @@ class TFLiteSemantic: self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_valid) self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_dimensionality) self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions) + self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions_axis) # Element-wise checks: for op_type in TFLiteSemantic.elem_wise_main_ops: @@ -447,6 +448,26 @@ class TFLiteSemantic: return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}" @staticmethod + def constraint_valid_dimensions_axis(op): + """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute""" + valid = True + extra = [] + ofm_shape = op.ofm.shape + ofm_dim = len(ofm_shape) + axis = op.attrs["axis"] + axis += ofm_dim if axis < 0 else 0 + + sum_ifm_axis = 0 + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + sum_ifm_axis += tens.shape[axis] + extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") + + valid = sum_ifm_axis == ofm_shape[axis] + extra = ", ".join(extra) + return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}" + + @staticmethod def constraint_stridedslice_input_count(op): "Exactly 4 Input tensors are required" inputs = len(op.inputs) |