aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp124
1 files changed, 124 insertions, 0 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index fd7f8dc7dc..44cd1800c4 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1648,6 +1648,123 @@ void SerializerStrategy::SerializeQuantizedLstmLayer(const armnn::IConnectableLa
CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
}
+void SerializerStrategy::SerializeUnidirectionalSequenceLstmLayer(
+ const armnn::IConnectableLayer* layer,
+ const armnn::UnidirectionalSequenceLstmDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name)
+{
+ IgnoreUnused(name);
+
+ auto fbUnidirectionalSequenceLstmBaseLayer =
+ CreateLayerBase(layer, serializer::LayerType::LayerType_UnidirectionalSequenceLstm);
+
+ auto fbUnidirectionalSequenceLstmDescriptor = serializer::CreateUnidirectionalSequenceLstmDescriptor(
+ m_flatBufferBuilder,
+ descriptor.m_ActivationFunc,
+ descriptor.m_ClippingThresCell,
+ descriptor.m_ClippingThresProj,
+ descriptor.m_CifgEnabled,
+ descriptor.m_PeepholeEnabled,
+ descriptor.m_ProjectionEnabled,
+ descriptor.m_LayerNormEnabled,
+ descriptor.m_TimeMajor);
+
+ // Index for constants vector
+ std::size_t i = 0;
+
+ // Get mandatory/basic input parameters
+ auto inputToForgetWeights = CreateConstTensorInfo(constants[i++]); //InputToForgetWeights
+ auto inputToCellWeights = CreateConstTensorInfo(constants[i++]); //InputToCellWeights
+ auto inputToOutputWeights = CreateConstTensorInfo(constants[i++]); //InputToOutputWeights
+ auto recurrentToForgetWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToForgetWeights
+ auto recurrentToCellWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToCellWeights
+ auto recurrentToOutputWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToOutputWeights
+ auto forgetGateBias = CreateConstTensorInfo(constants[i++]); //ForgetGateBias
+ auto cellBias = CreateConstTensorInfo(constants[i++]); //CellBias
+ auto outputGateBias = CreateConstTensorInfo(constants[i++]); //OutputGateBias
+
+ //Define optional parameters, these will be set depending on configuration in Lstm descriptor
+ flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+ flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+ flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+ flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
+
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = CreateConstTensorInfo(constants[i++]); //InputToInputWeights
+ recurrentToInputWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToInputWeights
+ inputGateBias = CreateConstTensorInfo(constants[i++]); //InputGateBias
+ }
+
+ if (descriptor.m_PeepholeEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ cellToInputWeights = CreateConstTensorInfo(constants[i++]); //CellToInputWeights
+ }
+ cellToForgetWeights = CreateConstTensorInfo(constants[i++]); //CellToForgetWeights
+ cellToOutputWeights = CreateConstTensorInfo(constants[i++]); //CellToOutputWeights
+ }
+
+ if (descriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = CreateConstTensorInfo(constants[i++]); //ProjectionWeights
+ projectionBias = CreateConstTensorInfo(constants[i++]); //ProjectionBias
+ }
+
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = CreateConstTensorInfo(constants[i++]); //InputLayerNormWeights
+ }
+ forgetLayerNormWeights = CreateConstTensorInfo(constants[i++]); //ForgetLayerNormWeights
+ cellLayerNormWeights = CreateConstTensorInfo(constants[i++]); //CellLayerNormWeights
+ outputLayerNormWeights = CreateConstTensorInfo(constants[i++]); //OutputLayerNormWeights
+ }
+
+ auto fbUnidirectionalSequenceLstmParams = serializer::CreateLstmInputParams(
+ m_flatBufferBuilder,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ forgetGateBias,
+ cellBias,
+ outputGateBias,
+ inputToInputWeights,
+ recurrentToInputWeights,
+ cellToInputWeights,
+ inputGateBias,
+ projectionWeights,
+ projectionBias,
+ cellToForgetWeights,
+ cellToOutputWeights,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
+
+ auto fbUnidirectionalSequenceLstmLayer = serializer::CreateUnidirectionalSequenceLstmLayer(
+ m_flatBufferBuilder,
+ fbUnidirectionalSequenceLstmBaseLayer,
+ fbUnidirectionalSequenceLstmDescriptor,
+ fbUnidirectionalSequenceLstmParams);
+
+ CreateAnyLayer(fbUnidirectionalSequenceLstmLayer.o, serializer::Layer::Layer_UnidirectionalSequenceLstmLayer);
+}
+
fb::Offset<serializer::LayerBase> SerializerStrategy::CreateLayerBase(const IConnectableLayer* layer,
const serializer::LayerType layerType)
{
@@ -2234,6 +2351,13 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
SerializeTransposeConvolution2dLayer(layer, layerDescriptor, constants, name);
break;
}
+ case armnn::LayerType::UnidirectionalSequenceLstm :
+ {
+ const armnn::UnidirectionalSequenceLstmDescriptor& layerDescriptor =
+ static_cast<const armnn::UnidirectionalSequenceLstmDescriptor&>(descriptor);
+ SerializeUnidirectionalSequenceLstmLayer(layer, layerDescriptor, constants, name);
+ break;
+ }
default:
{
throw InvalidArgumentException(