diff options
author | James Conroy <james.conroy@arm.com> | 2020-05-13 10:27:58 +0100 |
---|---|---|
committer | James Conroy <james.conroy@arm.com> | 2020-05-13 23:06:38 +0000 |
commit | 8d33318a7ac33d90ed79701ff717de8d9940cc67 (patch) | |
tree | 2cf4140ec37b5b0a43b9618bab7f4f8076b5f4ab /src/armnnSerializer/Serializer.cpp | |
parent | 5061601fb6833dda20a6097af6a92e5e07310f25 (diff) | |
download | armnn-8d33318a7ac33d90ed79701ff717de8d9940cc67.tar.gz |
IVGCVSW-4777 Add QLstm serialization support
* Adds serialization/deserilization for QLstm.
* 3 unit tests: basic, layer norm and advanced.
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I97d825e06b0d4a1257713cdd71ff06afa10d4380
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 119 |
1 files changed, 117 insertions, 2 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 355673697f..c4d3cfb5dd 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1335,9 +1335,124 @@ void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmInputParams& params, const char* name) { - IgnoreUnused(layer, descriptor, params, name); + IgnoreUnused(name); + + auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm); + + auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor( + m_flatBufferBuilder, + descriptor.m_CifgEnabled, + descriptor.m_PeepholeEnabled, + descriptor.m_ProjectionEnabled, + descriptor.m_LayerNormEnabled, + descriptor.m_CellClip, + descriptor.m_ProjectionClip, + descriptor.m_InputIntermediateScale, + descriptor.m_ForgetIntermediateScale, + descriptor.m_CellIntermediateScale, + descriptor.m_OutputIntermediateScale, + descriptor.m_HiddenStateZeroPoint, + descriptor.m_HiddenStateScale + ); + + // Mandatory params + auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights); + auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights); + auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights); + auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights); + auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights); + auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights); + auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias); + auto cellBias = CreateConstTensorInfo(*params.m_CellBias); + auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias); + + // CIFG + flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> inputGateBias; + + if (!descriptor.m_CifgEnabled) + { + inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights); + recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights); + inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias); + } + + // Projectiom + flatbuffers::Offset<serializer::ConstTensor> projectionWeights; + flatbuffers::Offset<serializer::ConstTensor> projectionBias; + + if (descriptor.m_ProjectionEnabled) + { + projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights); + projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias); + } + + // Peephole + flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights; + flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights; + + if (descriptor.m_PeepholeEnabled) + { + if (!descriptor.m_CifgEnabled) + { + cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights); + } + + cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights); + cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights); + } + + // Layer norm + flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights; + + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights)); + } + + forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights); + cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights); + outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights); + } + + auto fbQLstmParams = serializer::CreateQLstmInputParams( + m_flatBufferBuilder, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + forgetGateBias, + cellBias, + outputGateBias, + inputToInputWeights, + recurrentToInputWeights, + inputGateBias, + projectionWeights, + projectionBias, + cellToInputWeights, + cellToForgetWeights, + cellToOutputWeights, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights); + + auto fbQLstmLayer = serializer::CreateQLstmLayer( + m_flatBufferBuilder, + fbQLstmBaseLayer, + fbQLstmDescriptor, + fbQLstmParams); - throw UnimplementedException("SerializerVisitor::VisitQLstmLayer not yet implemented"); + CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer); } void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer, |