aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py44
1 files changed, 42 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25f19b77..457c35eb 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -69,8 +69,8 @@ class TFLiteSupportedOperators:
)
)
mac_main_ops = (
- # RNN/LSTM/GRU
- set((Op.BlockLSTM,))
+ # LSTM
+ set((Op.UnidirectionalSequenceLstm,))
# conv/depthwiseconv/transposeconv
| convolution_like_ops
# pooling
@@ -320,6 +320,14 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
+ # UnidirectionalSequenceLstm specific checks:
+ op_type = Op.UnidirectionalSequenceLstm
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_cifg)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_peep_hole)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -888,3 +896,35 @@ class TFLiteSupportedOperators:
"IFM depth must be no greater than 127"
ifm_depth = op.inputs[0].shape[-1]
return ifm_depth <= 127, f"IFM depth is {ifm_depth}"
+
+ @staticmethod
+ def constraint_lstm_no_cifg(op):
+ "Must not use CIFG"
+ cifg = None not in op.inputs[2:5] + op.inputs[6:9]
+ cifg = cifg and op.inputs[1] is None
+ cifg = cifg and op.inputs[5] is None
+ return not cifg, "Op uses CIFG"
+
+ @staticmethod
+ def constraint_lstm_no_peep_hole(op):
+ "Must not use Peephole"
+ valid = all([tens is None for tens in op.inputs[9:12]])
+ return valid, "Op uses peephole"
+
+ @staticmethod
+ def constraint_lstm_no_projection(op):
+ "Must not use Projection"
+ valid = all([tens is None for tens in op.inputs[16:18]])
+ return valid, "Op uses projection"
+
+ @staticmethod
+ def constraint_lstm_no_normalisation(op):
+ "Must not use Normalisation"
+ valid = all([tens is None for tens in op.inputs[20:24]])
+ return valid, "Op uses normalisation"
+
+ @staticmethod
+ def constraint_lstm_weights(op):
+ "All input and recurrent weights must be available"
+ valid = None not in op.inputs[1:9]
+ return valid, "Op has missing weights"