diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index d9ace1e..eff40bc 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -209,6 +209,10 @@ class TFLiteSemantic: # Exp specific checks self.specific_constraints[Op.Exp].append(TFLiteSemantic.constraint_input_signed) + # Transpose specific checks + self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_size) + self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_values) + def is_operator_semantic_valid(self, op): ext_type = optype_to_builtintype(op.type) @@ -833,6 +837,24 @@ class TFLiteSemantic: extra = ", ".join(extra) return valid, f"Op has non-variable state tensor(s): {extra}" + @staticmethod + def constraint_transpose_permutation_size(op): + "Permutation array must be a 1D tensor with RANK(IFM) elements" + dims = len(op.inputs[0].shape) + perm = op.inputs[1] + valid = len(perm.shape) == 1 and perm.shape[0] == dims + return valid, f"Op has ifm_dimension={dims} and permutation shape {perm.shape}" + + @staticmethod + def constraint_transpose_permutation_values(op): + "Permutation array must have constant values in the range [0, RANK(IFM))" + dims = len(op.inputs[0].shape) + perm = op.inputs[1] + valid = False + if perm.values is not None: + valid = not any([val < 0 or val >= dims for val in perm.values]) + return valid, f"Op has ifm_dimension={dims} and permutation values are: {perm.values}" + def tflite_semantic_checker(nng): semantic_checker = TFLiteSemantic() |