aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 7d544004..723c5f22 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -331,6 +331,7 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weight_dimensions)
# Rsqrt specific checks
self.specific_constraints[Op.Rsqrt].append(TFLiteSupportedOperators.constraint_rsqrt_input_int8)
@@ -967,6 +968,12 @@ class TFLiteSupportedOperators:
return valid, "Op has missing weights"
@staticmethod
+ def constraint_lstm_weight_dimensions(op):
+ "All recurrent weights must be 2D"
+ valid = all([len(input.shape) == 2 for input in op.inputs[5:9]])
+ return valid, "Op recurrent weights are not 2D"
+
+ @staticmethod
def constraint_rsqrt_input_int8(op):
"IFM must be int8"
ifm_dtype = op.ifm.dtype