aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp119
1 files changed, 117 insertions, 2 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 355673697f..c4d3cfb5dd 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1335,9 +1335,124 @@ void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer,
const armnn::LstmInputParams& params,
const char* name)
{
- IgnoreUnused(layer, descriptor, params, name);
+ IgnoreUnused(name);
+
+ auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm);
+
+ auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor(
+ m_flatBufferBuilder,
+ descriptor.m_CifgEnabled,
+ descriptor.m_PeepholeEnabled,
+ descriptor.m_ProjectionEnabled,
+ descriptor.m_LayerNormEnabled,
+ descriptor.m_CellClip,
+ descriptor.m_ProjectionClip,
+ descriptor.m_InputIntermediateScale,
+ descriptor.m_ForgetIntermediateScale,
+ descriptor.m_CellIntermediateScale,
+ descriptor.m_OutputIntermediateScale,
+ descriptor.m_HiddenStateZeroPoint,
+ descriptor.m_HiddenStateScale
+ );
+
+ // Mandatory params
+ auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
+ auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
+ auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
+ auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
+ auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
+ auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
+ auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
+ auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
+ auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
+
+ // CIFG
+ flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
+ recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
+ inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
+ }
+
+ // Projectiom
+ flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+ flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+
+ if (descriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
+ projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
+ }
+
+ // Peephole
+ flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+
+ if (descriptor.m_PeepholeEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
+ }
+
+ cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
+ cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
+ }
+
+ // Layer norm
+ flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
+
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
+ }
+
+ forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
+ cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
+ outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
+ }
+
+ auto fbQLstmParams = serializer::CreateQLstmInputParams(
+ m_flatBufferBuilder,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ forgetGateBias,
+ cellBias,
+ outputGateBias,
+ inputToInputWeights,
+ recurrentToInputWeights,
+ inputGateBias,
+ projectionWeights,
+ projectionBias,
+ cellToInputWeights,
+ cellToForgetWeights,
+ cellToOutputWeights,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
+
+ auto fbQLstmLayer = serializer::CreateQLstmLayer(
+ m_flatBufferBuilder,
+ fbQLstmBaseLayer,
+ fbQLstmDescriptor,
+ fbQLstmParams);
- throw UnimplementedException("SerializerVisitor::VisitQLstmLayer not yet implemented");
+ CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer);
}
void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,