aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-23 09:47:43 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-29 16:19:23 +0000
commit5b01a8994caea2857f3b991dc69a814f12ab7743 (patch)
tree434660d1ba049de847ee7b5ff9715bb618421831 /src/armnnSerializer/Serializer.cpp
parent61e71aa399a93cec44b23d43f2293e18d00f8e3a (diff)
downloadarmnn-5b01a8994caea2857f3b991dc69a814f12ab7743.tar.gz
IVGCVSW-3471 Add Serialization support for Quantized_LSTM
* Adds serialization/deserialization support * Adds related Unit test Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: Iaf271aa7d848bc3a69dbbf182389f2241c0ced5f
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp40
1 files changed, 39 insertions, 1 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 67c2f053e6..af4dc7a926 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1046,7 +1046,45 @@ void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer*
const armnn::QuantizedLstmInputParams& params,
const char* name)
{
- throw UnimplementedException("SerializerVisitor::VisitQuantizedLstmLayer not yet implemented");
+ auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
+
+ // Get input parameters
+ auto inputToInputWeights = CreateConstTensorInfo(params.get_InputToInputWeights());
+ auto inputToForgetWeights = CreateConstTensorInfo(params.get_InputToForgetWeights());
+ auto inputToCellWeights = CreateConstTensorInfo(params.get_InputToCellWeights());
+ auto inputToOutputWeights = CreateConstTensorInfo(params.get_InputToOutputWeights());
+
+ auto recurrentToInputWeights = CreateConstTensorInfo(params.get_RecurrentToInputWeights());
+ auto recurrentToForgetWeights = CreateConstTensorInfo(params.get_RecurrentToForgetWeights());
+ auto recurrentToCellWeights = CreateConstTensorInfo(params.get_RecurrentToCellWeights());
+ auto recurrentToOutputWeights = CreateConstTensorInfo(params.get_RecurrentToOutputWeights());
+
+ auto inputGateBias = CreateConstTensorInfo(params.get_InputGateBias());
+ auto forgetGateBias = CreateConstTensorInfo(params.get_ForgetGateBias());
+ auto cellBias = CreateConstTensorInfo(params.get_CellBias());
+ auto outputGateBias = CreateConstTensorInfo(params.get_OutputGateBias());
+
+ auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
+ m_flatBufferBuilder,
+ inputToInputWeights,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToInputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ inputGateBias,
+ forgetGateBias,
+ cellBias,
+ outputGateBias);
+
+ auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
+ m_flatBufferBuilder,
+ fbQuantizedLstmBaseLayer,
+ fbQuantizedLstmParams);
+
+ CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
}
fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,