aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWilliam Isaksson <william.isaksson@arm.com>2023-07-17 13:03:09 +0000
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-07-24 13:09:58 +0000
commit2f9b6874a227d8fa056c2e2fd01e8c80824ee0bc (patch)
treeebde7518ef8380ce4e250610528b4030d140e12a
parent9cf63a3612491198a39f6bd1f4a587589b3ac20a (diff)
downloadethos-u-vela-2f9b6874a227d8fa056c2e2fd01e8c80824ee0bc.tar.gz
MLBEDSW-7165: Update to TensorFlow 2.12
- Updated FlatBuffers files using TensorFlow 2.12.0 schema - Added restriction for UnidirectionalSequenceLSTM to have 2D recurrent weights to handle that diagonal_recurrent_tensors attr is not currently supported. Change-Id: I104fd1f52485b9b83d644772dbcdeea2d17585f0 Signed-off-by: William Isaksson <william.isaksson@arm.com>
-rw-r--r--README.md3
-rw-r--r--ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py12
-rw-r--r--ethosu/vela/tflite_mapping.py9
-rw-r--r--ethosu/vela/tflite_supported_operators.py7
4 files changed, 28 insertions, 3 deletions
diff --git a/README.md b/README.md
index 5e3c419..8b9f35e 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,8 @@ The tool has limited functionality for compiling a
(EXPERIMENTAL).
## TensorFlow Support
-* Vela 3.8.0 to current supports TensorFlow 2.11
+* Vela 3.9.0 to current supports TensorFlow 2.12
+* Vela 3.8.0 supports TensorFlow 2.11
* Vela 3.6.0 to 3.7.0 supports TensorFlow 2.10
* Vela 3.5.0 supports TensorFlow 2.9
* Vela 3.4.0 supports TensorFlow 2.8
diff --git a/ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py b/ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py
index 16641a7..73d298d 100644
--- a/ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py
+++ b/ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py
@@ -63,7 +63,14 @@ class UnidirectionalSequenceLSTMOptions(object):
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
-def UnidirectionalSequenceLSTMOptionsStart(builder): builder.StartObject(5)
+ # UnidirectionalSequenceLSTMOptions
+ def DiagonalRecurrentTensors(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
+def UnidirectionalSequenceLSTMOptionsStart(builder): builder.StartObject(6)
def Start(builder):
return UnidirectionalSequenceLSTMOptionsStart(builder)
def UnidirectionalSequenceLSTMOptionsAddFusedActivationFunction(builder, fusedActivationFunction): builder.PrependInt8Slot(0, fusedActivationFunction, 0)
@@ -81,6 +88,9 @@ def AddTimeMajor(builder, timeMajor):
def UnidirectionalSequenceLSTMOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): builder.PrependBoolSlot(4, asymmetricQuantizeInputs, 0)
def AddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs):
return UnidirectionalSequenceLSTMOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs)
+def UnidirectionalSequenceLSTMOptionsAddDiagonalRecurrentTensors(builder, diagonalRecurrentTensors): builder.PrependBoolSlot(5, diagonalRecurrentTensors, 0)
+def AddDiagonalRecurrentTensors(builder, diagonalRecurrentTensors):
+ return UnidirectionalSequenceLSTMOptionsAddDiagonalRecurrentTensors(builder, diagonalRecurrentTensors)
def UnidirectionalSequenceLSTMOptionsEnd(builder): return builder.EndObject()
def End(builder):
return UnidirectionalSequenceLSTMOptionsEnd(builder) \ No newline at end of file
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index 83c55bb..dba30c4 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -716,7 +716,14 @@ builtin_operator_map = {
Op.UnidirectionalSequenceLstm,
OptionsSerializer(
"UnidirectionalSequenceLSTMOptions",
- ("asymmetric_quantize_inputs", "cell_clip", fused_act, "proj_clip", "time_major"),
+ (
+ "asymmetric_quantize_inputs",
+ "cell_clip",
+ "diagonal_recurrent_tensors",
+ fused_act,
+ "proj_clip",
+ "time_major",
+ ),
),
TFLITE_IFM_WEIGHTS_INDICES,
),
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 7d54400..723c5f2 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -331,6 +331,7 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weight_dimensions)
# Rsqrt specific checks
self.specific_constraints[Op.Rsqrt].append(TFLiteSupportedOperators.constraint_rsqrt_input_int8)
@@ -967,6 +968,12 @@ class TFLiteSupportedOperators:
return valid, "Op has missing weights"
@staticmethod
+ def constraint_lstm_weight_dimensions(op):
+ "All recurrent weights must be 2D"
+ valid = all([len(input.shape) == 2 for input in op.inputs[5:9]])
+ return valid, "Op recurrent weights are not 2D"
+
+ @staticmethod
def constraint_rsqrt_input_int8(op):
"IFM must be int8"
ifm_dtype = op.ifm.dtype