Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--  src/armnnSerializer/Serializer.cpp  40
1 file changed, 39 insertions(+), 1 deletion(-)
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 67c2f053e6..af4dc7a926 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1046,7 +1046,45 @@ void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer*
const armnn::QuantizedLstmInputParams& params,
const char* name)
{
- throw UnimplementedException("SerializerVisitor::VisitQuantizedLstmLayer not yet implemented");
+ auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
+
+ // Get input parameters
+ auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
+ auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
+ auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
+ auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
+
+ auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
+ auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
+ auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
+ auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
+
+ auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
+ auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
+ auto cellBias = CreateConstTensorInfo(params.GetCellBias());
+ auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
+
+ auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
+ m_flatBufferBuilder,
+ inputToInputWeights,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToInputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ inputGateBias,
+ forgetGateBias,
+ cellBias,
+ outputGateBias);
+
+ auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
+ m_flatBufferBuilder,
+ fbQuantizedLstmBaseLayer,
+ fbQuantizedLstmParams);
+
+ CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
}
fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
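
For readers looking at this change in isolation, the following is a minimal, illustrative sketch (not part of the commit) of how the new visitor code is reached through the public Arm NN API: a network containing a QuantizedLstm layer is built and handed to the serializer, which walks the graph and calls SerializerVisitor::VisitQuantizedLstmLayer for that layer. The binding ids, layer names, the assumed input/output slot ordering, and the omitted output-slot TensorInfo setup are illustrative assumptions, not taken from this patch.

#include <armnn/INetwork.hpp>
#include <armnn/QuantizedLstmParams.hpp>
#include <armnnSerializer/ISerializer.hpp>
#include <fstream>

// Sketch only: builds a single-QuantizedLstm network and serializes it.
// 'params' must already hold the twelve constant tensors (weights and biases).
void SerializeQuantizedLstmNetwork(const armnn::QuantizedLstmInputParams& params)
{
    using namespace armnn;

    INetworkPtr network = INetwork::Create();

    // Assumed slot ordering: inputs {input, cellStateIn, outputStateIn},
    // outputs {cellStateOut, output}. Setting TensorInfo on the output
    // slots is omitted here for brevity.
    IConnectableLayer* input         = network->AddInputLayer(0, "input");
    IConnectableLayer* cellStateIn   = network->AddInputLayer(1, "cellStateIn");
    IConnectableLayer* outputStateIn = network->AddInputLayer(2, "outputStateIn");

    // This is the layer that SerializerVisitor::VisitQuantizedLstmLayer serializes.
    IConnectableLayer* qLstm = network->AddQuantizedLstmLayer(params, "quantizedLstm");

    IConnectableLayer* cellStateOut   = network->AddOutputLayer(0, "cellStateOut");
    IConnectableLayer* outputStateOut = network->AddOutputLayer(1, "outputStateOut");

    input->GetOutputSlot(0).Connect(qLstm->GetInputSlot(0));
    cellStateIn->GetOutputSlot(0).Connect(qLstm->GetInputSlot(1));
    outputStateIn->GetOutputSlot(0).Connect(qLstm->GetInputSlot(2));
    qLstm->GetOutputSlot(0).Connect(cellStateOut->GetInputSlot(0));
    qLstm->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0));

    // Serializing the network triggers the visitor, which writes the layer base
    // info plus the twelve constant tensors into the FlatBuffers model.
    auto serializer = armnnSerializer::ISerializer::Create();
    serializer->Serialize(*network);

    std::ofstream file("quantizedLstm.armnn", std::ios::binary);
    serializer->SaveSerializedToStream(file);
}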