aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
authorJim Flynn <jim.flynn@arm.com>2019-03-19 17:22:29 +0000
committerJim Flynn <jim.flynn@arm.com>2019-03-21 16:09:19 +0000
commit11af375a5a6bf88b4f3b933a86d53000b0d91ed0 (patch)
treef4f4db5192b275be44d96d96c7f3c8c10f15b3f1 /src/armnnSerializer/Serializer.cpp
parentdb059fd50f9afb398b8b12cd4592323fc8f60d7f (diff)
downloadarmnn-11af375a5a6bf88b4f3b933a86d53000b0d91ed0.tar.gz
IVGCVSW-2694: serialize/deserialize LSTM
* added serialize/deserialize methods for LSTM and tests Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c Signed-off-by: Nina Drozd <nina.drozd@arm.com> Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp84
1 files changed, 84 insertions, 0 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index a27cbc03ba..2fd840258e 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -375,6 +375,90 @@ void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer
CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
}
+void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor,
+ const armnn::LstmInputParams& params, const char* name)
+{
+ auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
+
+ auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
+ m_flatBufferBuilder,
+ descriptor.m_ActivationFunc,
+ descriptor.m_ClippingThresCell,
+ descriptor.m_ClippingThresProj,
+ descriptor.m_CifgEnabled,
+ descriptor.m_PeepholeEnabled,
+ descriptor.m_ProjectionEnabled);
+
+ // Get mandatory input parameters
+ auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
+ auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
+ auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
+ auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
+ auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
+ auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
+ auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
+ auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
+ auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
+
+ //Define optional parameters, these will be set depending on configuration in Lstm descriptor
+ flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+ flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+ flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+ flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
+ recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
+ cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
+ inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
+ }
+
+ if (descriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
+ projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
+ }
+
+ if (descriptor.m_PeepholeEnabled)
+ {
+ cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
+ cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
+ }
+
+ auto fbLstmParams = serializer::CreateLstmInputParams(
+ m_flatBufferBuilder,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ forgetGateBias,
+ cellBias,
+ outputGateBias,
+ inputToInputWeights,
+ recurrentToInputWeights,
+ cellToInputWeights,
+ inputGateBias,
+ projectionWeights,
+ projectionBias,
+ cellToForgetWeights,
+ cellToOutputWeights);
+
+ auto fbLstmLayer = serializer::CreateLstmLayer(
+ m_flatBufferBuilder,
+ fbLstmBaseLayer,
+ fbLstmDescriptor,
+ fbLstmParams);
+
+ CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
+}
+
void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
{
auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);