diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-07-23 14:47:49 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-07-28 12:03:02 +0100 |
commit | a0162e17c56538ee6d72ecce4c3e0836cbb34c56 (patch) | |
tree | c47230c4024d7e79cacb39dafe179cdcf4571ade /src/armnnSerializer/Serializer.cpp | |
parent | 996f0f59e5b8a9ac73503814f7aadff4ef74cd35 (diff) | |
download | armnn-a0162e17c56538ee6d72ecce4c3e0836cbb34c56.tar.gz |
MLCE-530 Add Serializer and Deserializer for UnidirectionalSequenceLstm
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ic1c56a57941ebede19ab8b9032e7f9df1221be7a
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index fd7f8dc7dc..44cd1800c4 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1648,6 +1648,123 @@ void SerializerStrategy::SerializeQuantizedLstmLayer(const armnn::IConnectableLa CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer); } +void SerializerStrategy::SerializeUnidirectionalSequenceLstmLayer( + const armnn::IConnectableLayer* layer, + const armnn::UnidirectionalSequenceLstmDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name) +{ + IgnoreUnused(name); + + auto fbUnidirectionalSequenceLstmBaseLayer = + CreateLayerBase(layer, serializer::LayerType::LayerType_UnidirectionalSequenceLstm); + + auto fbUnidirectionalSequenceLstmDescriptor = serializer::CreateUnidirectionalSequenceLstmDescriptor( + m_flatBufferBuilder, + descriptor.m_ActivationFunc, + descriptor.m_ClippingThresCell, + descriptor.m_ClippingThresProj, + descriptor.m_CifgEnabled, + descriptor.m_PeepholeEnabled, + descriptor.m_ProjectionEnabled, + descriptor.m_LayerNormEnabled, + descriptor.m_TimeMajor); + + // Index for constants vector + std::size_t i = 0; + + // Get mandatory/basic input parameters + auto inputToForgetWeights = CreateConstTensorInfo(constants[i++]); //InputToForgetWeights + auto inputToCellWeights = CreateConstTensorInfo(constants[i++]); //InputToCellWeights + auto inputToOutputWeights = CreateConstTensorInfo(constants[i++]); //InputToOutputWeights + auto recurrentToForgetWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToForgetWeights + auto recurrentToCellWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToCellWeights + auto recurrentToOutputWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToOutputWeights + auto forgetGateBias = CreateConstTensorInfo(constants[i++]); //ForgetGateBias + auto cellBias = CreateConstTensorInfo(constants[i++]); //CellBias + auto outputGateBias = CreateConstTensorInfo(constants[i++]); //OutputGateBias + + //Define optional parameters, these will be set depending on configuration in Lstm descriptor + flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights; + flatbuffers::Offset<serializer::ConstTensor> inputGateBias; + flatbuffers::Offset<serializer::ConstTensor> projectionWeights; + flatbuffers::Offset<serializer::ConstTensor> projectionBias; + flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights; + flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights; + flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights; + + if (!descriptor.m_CifgEnabled) + { + inputToInputWeights = CreateConstTensorInfo(constants[i++]); //InputToInputWeights + recurrentToInputWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToInputWeights + inputGateBias = CreateConstTensorInfo(constants[i++]); //InputGateBias + } + + if (descriptor.m_PeepholeEnabled) + { + if (!descriptor.m_CifgEnabled) + { + cellToInputWeights = CreateConstTensorInfo(constants[i++]); //CellToInputWeights + } + cellToForgetWeights = CreateConstTensorInfo(constants[i++]); //CellToForgetWeights + cellToOutputWeights = CreateConstTensorInfo(constants[i++]); //CellToOutputWeights + } + + if (descriptor.m_ProjectionEnabled) + { + projectionWeights = CreateConstTensorInfo(constants[i++]); //ProjectionWeights + projectionBias = CreateConstTensorInfo(constants[i++]); //ProjectionBias + } + + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + inputLayerNormWeights = CreateConstTensorInfo(constants[i++]); //InputLayerNormWeights + } + forgetLayerNormWeights = CreateConstTensorInfo(constants[i++]); //ForgetLayerNormWeights + cellLayerNormWeights = CreateConstTensorInfo(constants[i++]); //CellLayerNormWeights + outputLayerNormWeights = CreateConstTensorInfo(constants[i++]); //OutputLayerNormWeights + } + + auto fbUnidirectionalSequenceLstmParams = serializer::CreateLstmInputParams( + m_flatBufferBuilder, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + forgetGateBias, + cellBias, + outputGateBias, + inputToInputWeights, + recurrentToInputWeights, + cellToInputWeights, + inputGateBias, + projectionWeights, + projectionBias, + cellToForgetWeights, + cellToOutputWeights, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights); + + auto fbUnidirectionalSequenceLstmLayer = serializer::CreateUnidirectionalSequenceLstmLayer( + m_flatBufferBuilder, + fbUnidirectionalSequenceLstmBaseLayer, + fbUnidirectionalSequenceLstmDescriptor, + fbUnidirectionalSequenceLstmParams); + + CreateAnyLayer(fbUnidirectionalSequenceLstmLayer.o, serializer::Layer::Layer_UnidirectionalSequenceLstmLayer); +} + fb::Offset<serializer::LayerBase> SerializerStrategy::CreateLayerBase(const IConnectableLayer* layer, const serializer::LayerType layerType) { @@ -2234,6 +2351,13 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer, SerializeTransposeConvolution2dLayer(layer, layerDescriptor, constants, name); break; } + case armnn::LayerType::UnidirectionalSequenceLstm : + { + const armnn::UnidirectionalSequenceLstmDescriptor& layerDescriptor = + static_cast<const armnn::UnidirectionalSequenceLstmDescriptor&>(descriptor); + SerializeUnidirectionalSequenceLstmLayer(layer, layerDescriptor, constants, name); + break; + } default: { throw InvalidArgumentException( |