diff options
author | William Isaksson <william.isaksson@arm.com> | 2023-07-17 13:03:09 +0000 |
---|---|---|
committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2023-07-24 13:09:58 +0000 |
commit | 2f9b6874a227d8fa056c2e2fd01e8c80824ee0bc (patch) | |
tree | ebde7518ef8380ce4e250610528b4030d140e12a | |
parent | 9cf63a3612491198a39f6bd1f4a587589b3ac20a (diff) | |
download | ethos-u-vela-2f9b6874a227d8fa056c2e2fd01e8c80824ee0bc.tar.gz |
MLBEDSW-7165: Update to TensorFlow 2.12
- Updated FlatBuffers files using TensorFlow 2.12.0 schema
- Added restriction for UnidirectionalSequenceLSTM to have 2D recurrent
weights to handle that diagonal_recurrent_tensors attr is not
currently supported.
Change-Id: I104fd1f52485b9b83d644772dbcdeea2d17585f0
Signed-off-by: William Isaksson <william.isaksson@arm.com>
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py | 12 | ||||
-rw-r--r-- | ethosu/vela/tflite_mapping.py | 9 | ||||
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 7 |
4 files changed, 28 insertions, 3 deletions
@@ -45,7 +45,8 @@ The tool has limited functionality for compiling a (EXPERIMENTAL). ## TensorFlow Support -* Vela 3.8.0 to current supports TensorFlow 2.11 +* Vela 3.9.0 to current supports TensorFlow 2.12 +* Vela 3.8.0 supports TensorFlow 2.11 * Vela 3.6.0 to 3.7.0 supports TensorFlow 2.10 * Vela 3.5.0 supports TensorFlow 2.9 * Vela 3.4.0 supports TensorFlow 2.8 diff --git a/ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py b/ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py index 16641a7..73d298d 100644 --- a/ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py +++ b/ethosu/vela/tflite/UnidirectionalSequenceLSTMOptions.py @@ -63,7 +63,14 @@ class UnidirectionalSequenceLSTMOptions(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False -def UnidirectionalSequenceLSTMOptionsStart(builder): builder.StartObject(5) + # UnidirectionalSequenceLSTMOptions + def DiagonalRecurrentTensors(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def UnidirectionalSequenceLSTMOptionsStart(builder): builder.StartObject(6) def Start(builder): return UnidirectionalSequenceLSTMOptionsStart(builder) def UnidirectionalSequenceLSTMOptionsAddFusedActivationFunction(builder, fusedActivationFunction): builder.PrependInt8Slot(0, fusedActivationFunction, 0) @@ -81,6 +88,9 @@ def AddTimeMajor(builder, timeMajor): def UnidirectionalSequenceLSTMOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): builder.PrependBoolSlot(4, asymmetricQuantizeInputs, 0) def AddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): return UnidirectionalSequenceLSTMOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs) +def UnidirectionalSequenceLSTMOptionsAddDiagonalRecurrentTensors(builder, diagonalRecurrentTensors): builder.PrependBoolSlot(5, diagonalRecurrentTensors, 0) +def AddDiagonalRecurrentTensors(builder, diagonalRecurrentTensors): + return UnidirectionalSequenceLSTMOptionsAddDiagonalRecurrentTensors(builder, diagonalRecurrentTensors) def UnidirectionalSequenceLSTMOptionsEnd(builder): return builder.EndObject() def End(builder): return UnidirectionalSequenceLSTMOptionsEnd(builder)
\ No newline at end of file diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py index 83c55bb..dba30c4 100644 --- a/ethosu/vela/tflite_mapping.py +++ b/ethosu/vela/tflite_mapping.py @@ -716,7 +716,14 @@ builtin_operator_map = { Op.UnidirectionalSequenceLstm, OptionsSerializer( "UnidirectionalSequenceLSTMOptions", - ("asymmetric_quantize_inputs", "cell_clip", fused_act, "proj_clip", "time_major"), + ( + "asymmetric_quantize_inputs", + "cell_clip", + "diagonal_recurrent_tensors", + fused_act, + "proj_clip", + "time_major", + ), ), TFLITE_IFM_WEIGHTS_INDICES, ), diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 7d54400..723c5f2 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -331,6 +331,7 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weight_dimensions) # Rsqrt specific checks self.specific_constraints[Op.Rsqrt].append(TFLiteSupportedOperators.constraint_rsqrt_input_int8) @@ -967,6 +968,12 @@ class TFLiteSupportedOperators: return valid, "Op has missing weights" @staticmethod + def constraint_lstm_weight_dimensions(op): + "All recurrent weights must be 2D" + valid = all([len(input.shape) == 2 for input in op.inputs[5:9]]) + return valid, "Op recurrent weights are not 2D" + + @staticmethod def constraint_rsqrt_input_int8(op): "IFM must be int8" ifm_dtype = op.ifm.dtype |