aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 5661f36e..6ba7b835 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -193,6 +193,14 @@ class TFLiteSemantic:
self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_argmax_output)
+ # UnidirectionalSequenceLstm specific checks:
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_input_signed)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_matching_in_out_types)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_dimensions)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_inputs)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_intermediates)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_variables)
+
def is_operator_semantic_valid(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -628,6 +636,13 @@ class TFLiteSemantic:
return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
@staticmethod
+ def constraint_input_signed(op):
+ "IFM must be int8 or int16"
+ ifm_dtype = op.ifm.dtype
+ valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.int16)
+ return valid, f"Op has ifm_dtype={ifm_dtype}"
+
+ @staticmethod
def constraint_input_8bit(op):
"IFM must be int8 or uint8"
ifm_dtype = op.ifm.dtype
@@ -689,6 +704,36 @@ class TFLiteSemantic:
return False, f"IFM {op.ifm.shape} and OFM {op.ofm.shape} number of elements are not equal."
return True, "IFM and OFM number of elements are equal."
+ @staticmethod
+ def constraint_lstm_dimensions(op):
+ "IFM and OFM must have 3D shape"
+ valid = len(op.ifm.shape) == len(op.ofm.shape) == 3
+ return valid, f"Op has ifm shape {op.ifm.shape} and ofm shape {op.ofm.shape}"
+
+ @staticmethod
+ def constraint_lstm_inputs(op):
+ "Must have 24 input tensors"
+ n_inputs = len(op.inputs)
+ return n_inputs == 24, f"Op has {n_inputs} inputs"
+
+ @staticmethod
+ def constraint_lstm_intermediates(op):
+ "Must have 5 intermediate tensors"
+ n_intermediates = len(op.intermediates)
+ return n_intermediates == 5, f"Op has {n_intermediates} intermediates"
+
+ @staticmethod
+ def constraint_lstm_variables(op):
+ "State tensors must be variable"
+ valid = True
+ extra = []
+ for tens in op.inputs[18:20]:
+ if not tens.is_variable:
+ valid = False
+ extra.append(tens.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has non-variable state tensor(s): {extra}"
+
def tflite_semantic_checker(nng):
semantic_checker = TFLiteSemantic()