From 5b01a8994caea2857f3b991dc69a814f12ab7743 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Tue, 23 Jul 2019 09:47:43 +0100 Subject: IVGCVSW-3471 Add Serialization support for Quantized_LSTM * Adds serialization/deserialization support * Adds related Unit test Signed-off-by: Jan Eilers Change-Id: Iaf271aa7d848bc3a69dbbf182389f2241c0ced5f --- src/armnnSerializer/Serializer.cpp | 40 +++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) (limited to 'src/armnnSerializer/Serializer.cpp') diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 67c2f053e6..af4dc7a926 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1046,7 +1046,45 @@ void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* const armnn::QuantizedLstmInputParams& params, const char* name) { - throw UnimplementedException("SerializerVisitor::VisitQuantizedLstmLayer not yet implemented"); + auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm); + + // Get input parameters + auto inputToInputWeights = CreateConstTensorInfo(params.get_InputToInputWeights()); + auto inputToForgetWeights = CreateConstTensorInfo(params.get_InputToForgetWeights()); + auto inputToCellWeights = CreateConstTensorInfo(params.get_InputToCellWeights()); + auto inputToOutputWeights = CreateConstTensorInfo(params.get_InputToOutputWeights()); + + auto recurrentToInputWeights = CreateConstTensorInfo(params.get_RecurrentToInputWeights()); + auto recurrentToForgetWeights = CreateConstTensorInfo(params.get_RecurrentToForgetWeights()); + auto recurrentToCellWeights = CreateConstTensorInfo(params.get_RecurrentToCellWeights()); + auto recurrentToOutputWeights = CreateConstTensorInfo(params.get_RecurrentToOutputWeights()); + + auto inputGateBias = CreateConstTensorInfo(params.get_InputGateBias()); + auto forgetGateBias = CreateConstTensorInfo(params.get_ForgetGateBias()); + auto cellBias = CreateConstTensorInfo(params.get_CellBias()); + auto outputGateBias = CreateConstTensorInfo(params.get_OutputGateBias()); + + auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams( + m_flatBufferBuilder, + inputToInputWeights, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToInputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + inputGateBias, + forgetGateBias, + cellBias, + outputGateBias); + + auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer( + m_flatBufferBuilder, + fbQuantizedLstmBaseLayer, + fbQuantizedLstmParams); + + CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer); } fb::Offset SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, -- cgit v1.2.1