diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 40 |
1 files changed, 39 insertions, 1 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 67c2f053e6..af4dc7a926 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1046,7 +1046,45 @@ void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* const armnn::QuantizedLstmInputParams& params, const char* name) { - throw UnimplementedException("SerializerVisitor::VisitQuantizedLstmLayer not yet implemented"); + auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm); + + // Get input parameters + auto inputToInputWeights = CreateConstTensorInfo(params.get_InputToInputWeights()); + auto inputToForgetWeights = CreateConstTensorInfo(params.get_InputToForgetWeights()); + auto inputToCellWeights = CreateConstTensorInfo(params.get_InputToCellWeights()); + auto inputToOutputWeights = CreateConstTensorInfo(params.get_InputToOutputWeights()); + + auto recurrentToInputWeights = CreateConstTensorInfo(params.get_RecurrentToInputWeights()); + auto recurrentToForgetWeights = CreateConstTensorInfo(params.get_RecurrentToForgetWeights()); + auto recurrentToCellWeights = CreateConstTensorInfo(params.get_RecurrentToCellWeights()); + auto recurrentToOutputWeights = CreateConstTensorInfo(params.get_RecurrentToOutputWeights()); + + auto inputGateBias = CreateConstTensorInfo(params.get_InputGateBias()); + auto forgetGateBias = CreateConstTensorInfo(params.get_ForgetGateBias()); + auto cellBias = CreateConstTensorInfo(params.get_CellBias()); + auto outputGateBias = CreateConstTensorInfo(params.get_OutputGateBias()); + + auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams( + m_flatBufferBuilder, + inputToInputWeights, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToInputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + inputGateBias, + forgetGateBias, + cellBias, + outputGateBias); + + auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer( + m_flatBufferBuilder, + fbQuantizedLstmBaseLayer, + fbQuantizedLstmParams); + + CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer); } fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, |