diff options
author | Jan Eilers <jan.eilers@arm.com> | 2019-07-23 09:47:43 +0100 |
---|---|---|
committer | Jan Eilers <jan.eilers@arm.com> | 2019-07-29 16:19:23 +0000 |
commit | 5b01a8994caea2857f3b991dc69a814f12ab7743 (patch) | |
tree | 434660d1ba049de847ee7b5ff9715bb618421831 /src/armnnSerializer/Serializer.cpp | |
parent | 61e71aa399a93cec44b23d43f2293e18d00f8e3a (diff) | |
download | armnn-5b01a8994caea2857f3b991dc69a814f12ab7743.tar.gz |
IVGCVSW-3471 Add Serialization support for Quantized_LSTM
* Adds serialization/deserialization support
* Adds related Unit test
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: Iaf271aa7d848bc3a69dbbf182389f2241c0ced5f
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 40 |
1 files changed, 39 insertions, 1 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 67c2f053e6..af4dc7a926 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1046,7 +1046,45 @@ void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* const armnn::QuantizedLstmInputParams& params, const char* name) { - throw UnimplementedException("SerializerVisitor::VisitQuantizedLstmLayer not yet implemented"); + auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm); + + // Get input parameters + auto inputToInputWeights = CreateConstTensorInfo(params.get_InputToInputWeights()); + auto inputToForgetWeights = CreateConstTensorInfo(params.get_InputToForgetWeights()); + auto inputToCellWeights = CreateConstTensorInfo(params.get_InputToCellWeights()); + auto inputToOutputWeights = CreateConstTensorInfo(params.get_InputToOutputWeights()); + + auto recurrentToInputWeights = CreateConstTensorInfo(params.get_RecurrentToInputWeights()); + auto recurrentToForgetWeights = CreateConstTensorInfo(params.get_RecurrentToForgetWeights()); + auto recurrentToCellWeights = CreateConstTensorInfo(params.get_RecurrentToCellWeights()); + auto recurrentToOutputWeights = CreateConstTensorInfo(params.get_RecurrentToOutputWeights()); + + auto inputGateBias = CreateConstTensorInfo(params.get_InputGateBias()); + auto forgetGateBias = CreateConstTensorInfo(params.get_ForgetGateBias()); + auto cellBias = CreateConstTensorInfo(params.get_CellBias()); + auto outputGateBias = CreateConstTensorInfo(params.get_OutputGateBias()); + + auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams( + m_flatBufferBuilder, + inputToInputWeights, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToInputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + inputGateBias, + forgetGateBias, + cellBias, + outputGateBias); + + auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer( + m_flatBufferBuilder, + fbQuantizedLstmBaseLayer, + fbQuantizedLstmParams); + + CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer); } fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, |