diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 119 |
1 files changed, 117 insertions, 2 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 355673697f..c4d3cfb5dd 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1335,9 +1335,124 @@ void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmInputParams& params, const char* name) { - IgnoreUnused(layer, descriptor, params, name); + IgnoreUnused(name); + + auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm); + + auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor( + m_flatBufferBuilder, + descriptor.m_CifgEnabled, + descriptor.m_PeepholeEnabled, + descriptor.m_ProjectionEnabled, + descriptor.m_LayerNormEnabled, + descriptor.m_CellClip, + descriptor.m_ProjectionClip, + descriptor.m_InputIntermediateScale, + descriptor.m_ForgetIntermediateScale, + descriptor.m_CellIntermediateScale, + descriptor.m_OutputIntermediateScale, + descriptor.m_HiddenStateZeroPoint, + descriptor.m_HiddenStateScale + ); + + // Mandatory params + auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights); + auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights); + auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights); + auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights); + auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights); + auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights); + auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias); + auto cellBias = CreateConstTensorInfo(*params.m_CellBias); + auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias); + + // CIFG + flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> inputGateBias; + + if (!descriptor.m_CifgEnabled) + { + inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights); + recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights); + inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias); + } + + // Projectiom + flatbuffers::Offset<serializer::ConstTensor> projectionWeights; + flatbuffers::Offset<serializer::ConstTensor> projectionBias; + + if (descriptor.m_ProjectionEnabled) + { + projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights); + projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias); + } + + // Peephole + flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights; + flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights; + + if (descriptor.m_PeepholeEnabled) + { + if (!descriptor.m_CifgEnabled) + { + cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights); + } + + cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights); + cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights); + } + + // Layer norm + flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights; + + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights)); + } + + forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights); + cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights); + outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights); + } + + auto fbQLstmParams = serializer::CreateQLstmInputParams( + m_flatBufferBuilder, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + forgetGateBias, + cellBias, + outputGateBias, + inputToInputWeights, + recurrentToInputWeights, + inputGateBias, + projectionWeights, + projectionBias, + cellToInputWeights, + cellToForgetWeights, + cellToOutputWeights, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights); + + auto fbQLstmLayer = serializer::CreateQLstmLayer( + m_flatBufferBuilder, + fbQLstmBaseLayer, + fbQLstmDescriptor, + fbQLstmParams); - throw UnimplementedException("SerializerVisitor::VisitQLstmLayer not yet implemented"); + CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer); } void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer, |