diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 44 |
1 files changed, 42 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 25f19b77..457c35eb 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -69,8 +69,8 @@ class TFLiteSupportedOperators: ) ) mac_main_ops = ( - # RNN/LSTM/GRU - set((Op.BlockLSTM,)) + # LSTM + set((Op.UnidirectionalSequenceLstm,)) # conv/depthwiseconv/transposeconv | convolution_like_ops # pooling @@ -320,6 +320,14 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis) self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth) + # UnidirectionalSequenceLstm specific checks: + op_type = Op.UnidirectionalSequenceLstm + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_cifg) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_peep_hole) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights) + def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) if op.type not in TFLiteSupportedOperators.supported_operators: @@ -888,3 +896,35 @@ class TFLiteSupportedOperators: "IFM depth must be no greater than 127" ifm_depth = op.inputs[0].shape[-1] return ifm_depth <= 127, f"IFM depth is {ifm_depth}" + + @staticmethod + def constraint_lstm_no_cifg(op): + "Must not use CIFG" + cifg = None not in op.inputs[2:5] + op.inputs[6:9] + cifg = cifg and op.inputs[1] is None + cifg = cifg and op.inputs[5] is None + return not cifg, "Op uses CIFG" + + @staticmethod + def constraint_lstm_no_peep_hole(op): + "Must not use Peephole" + valid = all([tens is None for tens in op.inputs[9:12]]) + return valid, "Op uses peephole" + + @staticmethod + def constraint_lstm_no_projection(op): + "Must not use Projection" + valid = all([tens is None for tens in op.inputs[16:18]]) + return valid, "Op uses projection" + + @staticmethod + def constraint_lstm_no_normalisation(op): + "Must not use Normalisation" + valid = all([tens is None for tens in op.inputs[20:24]]) + return valid, "Op uses normalisation" + + @staticmethod + def constraint_lstm_weights(op): + "All input and recurrent weights must be available" + valid = None not in op.inputs[1:9] + return valid, "Op has missing weights" |