diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index a27cbc03ba..2fd840258e 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -375,6 +375,90 @@ void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer); } +void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor, + const armnn::LstmInputParams& params, const char* name) +{ + auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm); + + auto fbLstmDescriptor = serializer::CreateLstmDescriptor( + m_flatBufferBuilder, + descriptor.m_ActivationFunc, + descriptor.m_ClippingThresCell, + descriptor.m_ClippingThresProj, + descriptor.m_CifgEnabled, + descriptor.m_PeepholeEnabled, + descriptor.m_ProjectionEnabled); + + // Get mandatory input parameters + auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights); + auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights); + auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights); + auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights); + auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights); + auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights); + auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias); + auto cellBias = CreateConstTensorInfo(*params.m_CellBias); + auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias); + + //Define optional parameters, these will be set depending on configuration in Lstm descriptor + flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> inputGateBias; + flatbuffers::Offset<serializer::ConstTensor> projectionWeights; + flatbuffers::Offset<serializer::ConstTensor> projectionBias; + flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights; + flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights; + + if (!descriptor.m_CifgEnabled) + { + inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights); + recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights); + cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights); + inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias); + } + + if (descriptor.m_ProjectionEnabled) + { + projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights); + projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias); + } + + if (descriptor.m_PeepholeEnabled) + { + cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights); + cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights); + } + + auto fbLstmParams = serializer::CreateLstmInputParams( + m_flatBufferBuilder, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + forgetGateBias, + cellBias, + outputGateBias, + inputToInputWeights, + recurrentToInputWeights, + cellToInputWeights, + inputGateBias, + projectionWeights, + projectionBias, + cellToForgetWeights, + cellToOutputWeights); + + auto fbLstmLayer = serializer::CreateLstmLayer( + m_flatBufferBuilder, + fbLstmBaseLayer, + fbLstmDescriptor, + fbLstmParams); + + CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer); +} + void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name) { auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum); |