diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 7d544004..723c5f22 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -331,6 +331,7 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weight_dimensions) # Rsqrt specific checks self.specific_constraints[Op.Rsqrt].append(TFLiteSupportedOperators.constraint_rsqrt_input_int8) @@ -967,6 +968,12 @@ class TFLiteSupportedOperators: return valid, "Op has missing weights" @staticmethod + def constraint_lstm_weight_dimensions(op): + "All recurrent weights must be 2D" + valid = all([len(input.shape) == 2 for input in op.inputs[5:9]]) + return valid, "Op recurrent weights are not 2D" + + @staticmethod def constraint_rsqrt_input_int8(op): "IFM must be int8" ifm_dtype = op.ifm.dtype |