diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 5661f36e..6ba7b835 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -193,6 +193,14 @@ class TFLiteSemantic: self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit) self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_argmax_output) + # UnidirectionalSequenceLstm specific checks: + self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_input_signed) + self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_matching_in_out_types) + self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_dimensions) + self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_inputs) + self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_intermediates) + self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_variables) + def is_operator_semantic_valid(self, op): ext_type = optype_to_builtintype(op.type) @@ -628,6 +636,13 @@ class TFLiteSemantic: return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}" @staticmethod + def constraint_input_signed(op): + "IFM must be int8 or int16" + ifm_dtype = op.ifm.dtype + valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.int16) + return valid, f"Op has ifm_dtype={ifm_dtype}" + + @staticmethod def constraint_input_8bit(op): "IFM must be int8 or uint8" ifm_dtype = op.ifm.dtype @@ -689,6 +704,36 @@ class TFLiteSemantic: return False, f"IFM {op.ifm.shape} and OFM {op.ofm.shape} number of elements are not equal." return True, "IFM and OFM number of elements are equal." + @staticmethod + def constraint_lstm_dimensions(op): + "IFM and OFM must have 3D shape" + valid = len(op.ifm.shape) == len(op.ofm.shape) == 3 + return valid, f"Op has ifm shape {op.ifm.shape} and ofm shape {op.ofm.shape}" + + @staticmethod + def constraint_lstm_inputs(op): + "Must have 24 input tensors" + n_inputs = len(op.inputs) + return n_inputs == 24, f"Op has {n_inputs} inputs" + + @staticmethod + def constraint_lstm_intermediates(op): + "Must have 5 intermediate tensors" + n_intermediates = len(op.intermediates) + return n_intermediates == 5, f"Op has {n_intermediates} intermediates" + + @staticmethod + def constraint_lstm_variables(op): + "State tensors must be variable" + valid = True + extra = [] + for tens in op.inputs[18:20]: + if not tens.is_variable: + valid = False + extra.append(tens.name) + extra = ", ".join(extra) + return valid, f"Op has non-variable state tensor(s): {extra}" + def tflite_semantic_checker(nng): semantic_checker = TFLiteSemantic() |