aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJim Flynn <jim.flynn@arm.com>2019-03-19 17:22:29 +0000
committerJim Flynn <jim.flynn@arm.com>2019-03-21 16:09:19 +0000
commit11af375a5a6bf88b4f3b933a86d53000b0d91ed0 (patch)
treef4f4db5192b275be44d96d96c7f3c8c10f15b3f1
parentdb059fd50f9afb398b8b12cd4592323fc8f60d7f (diff)
downloadarmnn-11af375a5a6bf88b4f3b933a86d53000b0d91ed0.tar.gz
IVGCVSW-2694: serialize/deserialize LSTM
* added serialize/deserialize methods for LSTM and tests Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c Signed-off-by: Nina Drozd <nina.drozd@arm.com> Signed-off-by: Jim Flynn <jim.flynn@arm.com>
-rw-r--r--src/armnn/layers/LstmLayer.cpp92
-rw-r--r--src/armnnDeserializer/Deserializer.cpp113
-rw-r--r--src/armnnDeserializer/Deserializer.hpp6
-rw-r--r--src/armnnDeserializer/DeserializerSupport.md1
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs44
-rw-r--r--src/armnnSerializer/Serializer.cpp84
-rw-r--r--src/armnnSerializer/Serializer.hpp5
-rw-r--r--src/armnnSerializer/SerializerSupport.md1
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp375
9 files changed, 690 insertions, 31 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp
index fa836d0317..2b99f284e8 100644
--- a/src/armnn/layers/LstmLayer.cpp
+++ b/src/armnn/layers/LstmLayer.cpp
@@ -252,110 +252,144 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
void LstmLayer::Accept(ILayerVisitor& visitor) const
{
LstmInputParams inputParams;
+ ConstTensor inputToInputWeightsTensor;
if (m_CifgParameters.m_InputToInputWeights != nullptr)
{
- ConstTensor inputToInputWeightsTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_InputToInputWeights->Map(true));
+ ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
+ m_CifgParameters.m_InputToInputWeights->Map(true));
+ inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
}
+ ConstTensor inputToForgetWeightsTensor;
if (m_BasicParameters.m_InputToForgetWeights != nullptr)
{
- ConstTensor inputToForgetWeightsTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToForgetWeights->Map(true));
+ ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToForgetWeights->Map(true));
+ inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
}
+ ConstTensor inputToCellWeightsTensor;
if (m_BasicParameters.m_InputToCellWeights != nullptr)
{
- ConstTensor inputToCellWeightsTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToCellWeights->Map(true));
+ ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToCellWeights->Map(true));
+ inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
}
+ ConstTensor inputToOutputWeightsTensor;
if (m_BasicParameters.m_InputToOutputWeights != nullptr)
{
- ConstTensor inputToOutputWeightsTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToOutputWeights->Map(true));
+ ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToOutputWeights->Map(true));
+ inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
}
+ ConstTensor recurrentToInputWeightsTensor;
if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
{
- ConstTensor recurrentToInputWeightsTensor(
+ ConstTensor recurrentToInputWeightsTensorCopy(
m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(),
m_CifgParameters.m_RecurrentToInputWeights->Map(true));
+ recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
}
+ ConstTensor recurrentToForgetWeightsTensor;
if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
{
- ConstTensor recurrentToForgetWeightsTensor(
+ ConstTensor recurrentToForgetWeightsTensorCopy(
m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
m_BasicParameters.m_RecurrentToForgetWeights->Map(true));
+ recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
}
+ ConstTensor recurrentToCellWeightsTensor;
if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
{
- ConstTensor recurrentToCellWeightsTensor(
+ ConstTensor recurrentToCellWeightsTensorCopy(
m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(),
m_BasicParameters.m_RecurrentToCellWeights->Map(true));
+ recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
}
+ ConstTensor recurrentToOutputWeightsTensor;
if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
{
- ConstTensor recurrentToOutputWeightsTensor(
+ ConstTensor recurrentToOutputWeightsTensorCopy(
m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
m_BasicParameters.m_RecurrentToOutputWeights->Map(true));
+ recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
}
+ ConstTensor cellToInputWeightsTensor;
if (m_CifgParameters.m_CellToInputWeights != nullptr)
{
- ConstTensor cellToInputWeightsTensor(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_CellToInputWeights->Map(true));
+ ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(),
+ m_CifgParameters.m_CellToInputWeights->Map(true));
+ cellToInputWeightsTensor = cellToInputWeightsTensorCopy;
inputParams.m_CellToInputWeights = &cellToInputWeightsTensor;
}
+ ConstTensor cellToForgetWeightsTensor;
if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
{
- ConstTensor cellToForgetWeightsTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToForgetWeights->Map(true));
+ ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
+ m_PeepholeParameters.m_CellToForgetWeights->Map(true));
+ cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy;
inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor;
}
+ ConstTensor cellToOutputWeightsTensor;
if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
{
- ConstTensor cellToOutputWeightsTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToOutputWeights->Map(true));
+ ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
+ m_PeepholeParameters.m_CellToOutputWeights->Map(true));
+ cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy;
inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor;
}
+ ConstTensor inputGateBiasTensor;
if (m_CifgParameters.m_InputGateBias != nullptr)
{
- ConstTensor inputGateBiasTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
+ ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
m_CifgParameters.m_InputGateBias->Map(true));
+ inputGateBiasTensor = inputGateBiasTensorCopy;
inputParams.m_InputGateBias = &inputGateBiasTensor;
}
+ ConstTensor forgetGateBiasTensor;
if (m_BasicParameters.m_ForgetGateBias != nullptr)
{
- ConstTensor forgetGateBiasTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
- m_BasicParameters.m_ForgetGateBias->Map(true));
+ ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
+ m_BasicParameters.m_ForgetGateBias->Map(true));
+ forgetGateBiasTensor = forgetGateBiasTensorCopy;
inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
}
+ ConstTensor cellBiasTensor;
if (m_BasicParameters.m_CellBias != nullptr)
{
- ConstTensor cellBiasTensor(m_BasicParameters.m_CellBias->GetTensorInfo(),
- m_BasicParameters.m_CellBias->Map(true));
+ ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(),
+ m_BasicParameters.m_CellBias->Map(true));
+ cellBiasTensor = cellBiasTensorCopy;
inputParams.m_CellBias = &cellBiasTensor;
}
+ ConstTensor outputGateBias;
if (m_BasicParameters.m_OutputGateBias != nullptr)
{
- ConstTensor outputGateBias(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
- m_BasicParameters.m_OutputGateBias->Map(true));
+ ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
+ m_BasicParameters.m_OutputGateBias->Map(true));
+ outputGateBias = outputGateBiasCopy;
inputParams.m_OutputGateBias = &outputGateBias;
}
+ ConstTensor projectionWeightsTensor;
if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
{
- ConstTensor projectionWeightsTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
- m_ProjectionParameters.m_ProjectionWeights->Map(true));
+ ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
+ m_ProjectionParameters.m_ProjectionWeights->Map(true));
+ projectionWeightsTensor = projectionWeightsTensorCopy;
inputParams.m_ProjectionWeights = &projectionWeightsTensor;
}
+ ConstTensor projectionBiasTensor;
if (m_ProjectionParameters.m_ProjectionBias != nullptr)
{
- ConstTensor projectionBiasTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
- m_ProjectionParameters.m_ProjectionBias->Map(true));
+ ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
+ m_ProjectionParameters.m_ProjectionBias->Map(true));
+ projectionBiasTensor = projectionBiasTensorCopy;
inputParams.m_ProjectionBias = &projectionBiasTensor;
}
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 152a5b4c93..d64bed7409 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -201,6 +201,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
m_ParserFunctions[Layer_GatherLayer] = &Deserializer::ParseGather;
m_ParserFunctions[Layer_GreaterLayer] = &Deserializer::ParseGreater;
m_ParserFunctions[Layer_L2NormalizationLayer] = &Deserializer::ParseL2Normalization;
+ m_ParserFunctions[Layer_LstmLayer] = &Deserializer::ParseLstm;
m_ParserFunctions[Layer_MaximumLayer] = &Deserializer::ParseMaximum;
m_ParserFunctions[Layer_MeanLayer] = &Deserializer::ParseMean;
m_ParserFunctions[Layer_MinimumLayer] = &Deserializer::ParseMinimum;
@@ -258,6 +259,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base();
case Layer::Layer_L2NormalizationLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_L2NormalizationLayer()->base();
+ case Layer::Layer_LstmLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_LstmLayer()->base();
case Layer::Layer_MeanLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_MeanLayer()->base();
case Layer::Layer_MinimumLayer:
@@ -1927,4 +1930,114 @@ void Deserializer::ParseSplitter(GraphPtr graph, unsigned int layerIndex)
RegisterOutputSlots(graph, layerIndex, layer);
}
+armnn::LstmDescriptor Deserializer::GetLstmDescriptor(Deserializer::LstmDescriptorPtr lstmDescriptor)
+{
+ armnn::LstmDescriptor desc;
+
+ desc.m_ActivationFunc = lstmDescriptor->activationFunc();
+ desc.m_ClippingThresCell = lstmDescriptor->clippingThresCell();
+ desc.m_ClippingThresProj = lstmDescriptor->clippingThresProj();
+ desc.m_CifgEnabled = lstmDescriptor->cifgEnabled();
+ desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled();
+ desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled();
+
+ return desc;
+}
+
+void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+
+ auto inputs = GetInputs(graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 3);
+
+ auto outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 4);
+
+ auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_LstmLayer();
+ auto layerName = GetLayerName(graph, layerIndex);
+ auto flatBufferDescriptor = flatBufferLayer->descriptor();
+ auto flatBufferInputParams = flatBufferLayer->inputParams();
+
+ auto lstmDescriptor = GetLstmDescriptor(flatBufferDescriptor);
+
+ armnn::LstmInputParams lstmInputParams;
+
+ armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
+ armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
+ armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
+ armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
+ armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
+ armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
+ armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
+ armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
+ armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
+
+ lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
+ lstmInputParams.m_InputToCellWeights = &inputToCellWeights;
+ lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
+ lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ lstmInputParams.m_ForgetGateBias = &forgetGateBias;
+ lstmInputParams.m_CellBias = &cellBias;
+ lstmInputParams.m_OutputGateBias = &outputGateBias;
+
+ armnn::ConstTensor inputToInputWeights;
+ armnn::ConstTensor recurrentToInputWeights;
+ armnn::ConstTensor cellToInputWeights;
+ armnn::ConstTensor inputGateBias;
+ if (!lstmDescriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
+ recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
+ cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
+ inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
+
+ lstmInputParams.m_InputToInputWeights = &inputToInputWeights;
+ lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ lstmInputParams.m_CellToInputWeights = &cellToInputWeights;
+ lstmInputParams.m_InputGateBias = &inputGateBias;
+ }
+
+ armnn::ConstTensor projectionWeights;
+ armnn::ConstTensor projectionBias;
+ if (lstmDescriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
+ projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());
+
+ lstmInputParams.m_ProjectionWeights = &projectionWeights;
+ lstmInputParams.m_ProjectionBias = &projectionBias;
+ }
+
+ armnn::ConstTensor cellToForgetWeights;
+ armnn::ConstTensor cellToOutputWeights;
+ if (lstmDescriptor.m_PeepholeEnabled)
+ {
+ cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
+ cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
+
+ lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
+ lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
+ }
+
+ IConnectableLayer* layer = m_Network->AddLstmLayer(lstmDescriptor, lstmInputParams, layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1);
+
+ armnn::TensorInfo outputTensorInfo2 = ToTensorInfo(outputs[1]);
+ layer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo2);
+
+ armnn::TensorInfo outputTensorInfo3 = ToTensorInfo(outputs[2]);
+ layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo3);
+
+ armnn::TensorInfo outputTensorInfo4 = ToTensorInfo(outputs[3]);
+ layer->GetOutputSlot(3).SetTensorInfo(outputTensorInfo4);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
+
} // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index effc7ae144..6454643f98 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -22,6 +22,8 @@ public:
using TensorRawPtr = const armnnSerializer::TensorInfo *;
using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
+ using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
+ using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
using TensorRawPtrVector = std::vector<TensorRawPtr>;
using LayerRawPtr = const armnnSerializer::LayerBase *;
using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
@@ -58,6 +60,9 @@ public:
unsigned int layerIndex);
static armnn::NormalizationDescriptor GetNormalizationDescriptor(
NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex);
+ static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor);
+ static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
+ LstmInputParamsPtr lstmInputParams);
static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
const std::vector<uint32_t> & targetDimsIn);
@@ -94,6 +99,7 @@ private:
void ParseMerger(GraphPtr graph, unsigned int layerIndex);
void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
+ void ParseLstm(GraphPtr graph, unsigned int layerIndex);
void ParsePad(GraphPtr graph, unsigned int layerIndex);
void ParsePermute(GraphPtr graph, unsigned int layerIndex);
void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index 48b8c88103..d53252ec00 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -21,6 +21,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Gather
* Greater
* L2Normalization
+* Lstm
* Maximum
* Mean
* Merger
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index a11eeadf12..2cceaae031 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -115,7 +115,8 @@ enum LayerType : uint {
Merger = 30,
L2Normalization = 31,
Splitter = 32,
- DetectionPostProcess = 33
+ DetectionPostProcess = 33,
+ Lstm = 34
}
// Base layer table to be used as part of other layers
@@ -475,6 +476,44 @@ table DetectionPostProcessDescriptor {
scaleH:float;
}
+table LstmInputParams {
+ inputToForgetWeights:ConstTensor;
+ inputToCellWeights:ConstTensor;
+ inputToOutputWeights:ConstTensor;
+ recurrentToForgetWeights:ConstTensor;
+ recurrentToCellWeights:ConstTensor;
+ recurrentToOutputWeights:ConstTensor;
+ forgetGateBias:ConstTensor;
+ cellBias:ConstTensor;
+ outputGateBias:ConstTensor;
+
+ inputToInputWeights:ConstTensor;
+ recurrentToInputWeights:ConstTensor;
+ cellToInputWeights:ConstTensor;
+ inputGateBias:ConstTensor;
+
+ projectionWeights:ConstTensor;
+ projectionBias:ConstTensor;
+
+ cellToForgetWeights:ConstTensor;
+ cellToOutputWeights:ConstTensor;
+}
+
+table LstmDescriptor {
+ activationFunc:uint;
+ clippingThresCell:float;
+ clippingThresProj:float;
+ cifgEnabled:bool = true;
+ peepholeEnabled:bool = false;
+ projectionEnabled:bool = false;
+}
+
+table LstmLayer {
+ base:LayerBase;
+ descriptor:LstmDescriptor;
+ inputParams:LstmInputParams;
+}
+
union Layer {
ActivationLayer,
AdditionLayer,
@@ -509,7 +548,8 @@ union Layer {
MergerLayer,
L2NormalizationLayer,
SplitterLayer,
- DetectionPostProcessLayer
+ DetectionPostProcessLayer,
+ LstmLayer
}
table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index a27cbc03ba..2fd840258e 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -375,6 +375,90 @@ void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer
CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
}
+void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor,
+ const armnn::LstmInputParams& params, const char* name)
+{
+ auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
+
+ auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
+ m_flatBufferBuilder,
+ descriptor.m_ActivationFunc,
+ descriptor.m_ClippingThresCell,
+ descriptor.m_ClippingThresProj,
+ descriptor.m_CifgEnabled,
+ descriptor.m_PeepholeEnabled,
+ descriptor.m_ProjectionEnabled);
+
+ // Get mandatory input parameters
+ auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
+ auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
+ auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
+ auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
+ auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
+ auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
+ auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
+ auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
+ auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
+
+ //Define optional parameters, these will be set depending on configuration in Lstm descriptor
+ flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+ flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+ flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+ flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
+ recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
+ cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
+ inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
+ }
+
+ if (descriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
+ projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
+ }
+
+ if (descriptor.m_PeepholeEnabled)
+ {
+ cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
+ cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
+ }
+
+ auto fbLstmParams = serializer::CreateLstmInputParams(
+ m_flatBufferBuilder,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ forgetGateBias,
+ cellBias,
+ outputGateBias,
+ inputToInputWeights,
+ recurrentToInputWeights,
+ cellToInputWeights,
+ inputGateBias,
+ projectionWeights,
+ projectionBias,
+ cellToForgetWeights,
+ cellToOutputWeights);
+
+ auto fbLstmLayer = serializer::CreateLstmLayer(
+ m_flatBufferBuilder,
+ fbLstmBaseLayer,
+ fbLstmDescriptor,
+ fbLstmParams);
+
+ CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
+}
+
void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
{
auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 71066d2699..4573bfdd11 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -111,6 +111,11 @@ public:
const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
const char* name = nullptr) override;
+ void VisitLstmLayer(const armnn::IConnectableLayer* layer,
+ const armnn::LstmDescriptor& descriptor,
+ const armnn::LstmInputParams& params,
+ const char* name = nullptr) override;
+
void VisitMeanLayer(const armnn::IConnectableLayer* layer,
const armnn::MeanDescriptor& descriptor,
const char* name) override;
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index 4e127b3f9f..7686d5c24c 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -21,6 +21,7 @@ The Arm NN SDK Serializer currently supports the following layers:
* Gather
* Greater
* L2Normalization
+* Lstm
* Maximum
* Mean
* Merger
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index f40c02dfde..e3ce6d29d3 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -2047,4 +2047,379 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeNonLinearNetwork)
deserializedNetwork->Accept(verifier);
}
+class VerifyLstmLayer : public LayerVerifierBase
+{
+public:
+ VerifyLstmLayer(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const armnn::LstmDescriptor& descriptor,
+ const armnn::LstmInputParams& inputParams) :
+ LayerVerifierBase(layerName, inputInfos, outputInfos), m_Descriptor(descriptor), m_InputParams(inputParams)
+ {
+ }
+ void VisitLstmLayer(const armnn::IConnectableLayer* layer,
+ const armnn::LstmDescriptor& descriptor,
+ const armnn::LstmInputParams& params,
+ const char* name)
+ {
+ VerifyNameAndConnections(layer, name);
+ VerifyDescriptor(descriptor);
+ VerifyInputParameters(params);
+ }
+protected:
+ void VerifyDescriptor(const armnn::LstmDescriptor& descriptor)
+ {
+ BOOST_TEST(m_Descriptor.m_ActivationFunc == descriptor.m_ActivationFunc);
+ BOOST_TEST(m_Descriptor.m_ClippingThresCell == descriptor.m_ClippingThresCell);
+ BOOST_TEST(m_Descriptor.m_ClippingThresProj == descriptor.m_ClippingThresProj);
+ BOOST_TEST(m_Descriptor.m_CifgEnabled == descriptor.m_CifgEnabled);
+ BOOST_TEST(m_Descriptor.m_PeepholeEnabled = descriptor.m_PeepholeEnabled);
+ BOOST_TEST(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled);
+ }
+ void VerifyInputParameters(const armnn::LstmInputParams& params)
+ {
+ VerifyConstTensors(
+ "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights);
+ VerifyConstTensors(
+ "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights);
+ VerifyConstTensors(
+ "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights);
+ VerifyConstTensors(
+ "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights);
+ VerifyConstTensors(
+ "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights);
+ VerifyConstTensors(
+ "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights);
+ VerifyConstTensors(
+ "m_RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights);
+ VerifyConstTensors(
+ "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights);
+ VerifyConstTensors(
+ "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights);
+ VerifyConstTensors(
+ "m_CellToForgetWeights", m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights);
+ VerifyConstTensors(
+ "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights);
+ VerifyConstTensors(
+ "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias);
+ VerifyConstTensors(
+ "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias);
+ VerifyConstTensors(
+ "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias);
+ VerifyConstTensors(
+ "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias);
+ VerifyConstTensors(
+ "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights);
+ VerifyConstTensors(
+ "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias);
+ }
+ void VerifyConstTensors(const std::string& tensorName,
+ const armnn::ConstTensor* expectedPtr,
+ const armnn::ConstTensor* actualPtr)
+ {
+ if (expectedPtr == nullptr)
+ {
+ BOOST_CHECK_MESSAGE(actualPtr == nullptr, tensorName + " should not exist");
+ }
+ else
+ {
+ BOOST_CHECK_MESSAGE(actualPtr != nullptr, tensorName + " should have been set");
+ if (actualPtr != nullptr)
+ {
+ const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
+ const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
+
+ BOOST_CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
+ tensorName + " shapes don't match");
+ BOOST_CHECK_MESSAGE(
+ GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
+ tensorName + " data types don't match");
+
+ BOOST_CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
+ tensorName + " (GetNumBytes) data sizes do not match");
+ if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
+ {
+ //check the data is identical
+ const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
+ const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
+ bool same = true;
+ for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
+ {
+ same = expectedData[i] == actualData[i];
+ if (!same)
+ {
+ break;
+ }
+ }
+ BOOST_CHECK_MESSAGE(same, tensorName + " data does not match");
+ }
+ }
+ }
+ }
+private:
+ armnn::LstmDescriptor m_Descriptor;
+ armnn::LstmInputParams m_InputParams;
+};
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection)
+{
+ armnn::LstmDescriptor descriptor;
+ descriptor.m_ActivationFunc = 4;
+ descriptor.m_ClippingThresProj = 0.0f;
+ descriptor.m_ClippingThresCell = 0.0f;
+ descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams
+ descriptor.m_ProjectionEnabled = false;
+ descriptor.m_PeepholeEnabled = true;
+
+ const uint32_t batchSize = 1;
+ const uint32_t inputSize = 2;
+ const uint32_t numUnits = 4;
+ const uint32_t outputSize = numUnits;
+
+ armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32);
+ std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData);
+
+ std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData);
+
+ std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData);
+
+ armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32);
+ std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData);
+
+ std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData);
+
+ std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData);
+
+ armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32);
+ std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+ armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData);
+
+ std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+ armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData);
+
+ std::vector<float> forgetGateBiasData(numUnits, 1.0f);
+ armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData);
+
+ std::vector<float> cellBiasData(numUnits, 0.0f);
+ armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData);
+
+ std::vector<float> outputGateBiasData(numUnits, 0.0f);
+ armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData);
+
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+ const std::string layerName("lstm");
+ armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str());
+ armnn::IConnectableLayer* const scratchBuffer = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(2);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(3);
+
+ // connect up
+ armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 3 }, armnn::DataType::Float32);
+
+ inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff);
+
+ lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyLstmLayer checker(
+ layerName,
+ {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+ {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo},
+ descriptor,
+ params);
+ deserializedNetwork->Accept(checker);
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection)
+{
+ armnn::LstmDescriptor descriptor;
+ descriptor.m_ActivationFunc = 4;
+ descriptor.m_ClippingThresProj = 0.0f;
+ descriptor.m_ClippingThresCell = 0.0f;
+ descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams
+ descriptor.m_ProjectionEnabled = true;
+ descriptor.m_PeepholeEnabled = true;
+
+ const uint32_t batchSize = 2;
+ const uint32_t inputSize = 5;
+ const uint32_t numUnits = 20;
+ const uint32_t outputSize = 16;
+
+ armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32);
+ std::vector<float> inputToInputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData);
+
+ std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData);
+
+ std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData);
+
+ std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData);
+
+ armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32);
+ std::vector<float> inputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData);
+
+ std::vector<float> forgetGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData);
+
+ std::vector<float> cellBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellBias(tensorInfo20, cellBiasData);
+
+ std::vector<float> outputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData);
+
+ armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32);
+ std::vector<float> recurrentToInputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData);
+
+ std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData);
+
+ std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData);
+
+ std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData);
+
+ std::vector<float> cellToInputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData);
+
+ std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData);
+
+ std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData);
+
+ armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32);
+ std::vector<float> projectionWeightsData = GenerateRandomData<float>(tensorInfo16x20.GetNumElements());
+ armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData);
+
+ armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32);
+ std::vector<float> projectionBiasData(outputSize, 0.f);
+ armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData);
+
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // additional params because: descriptor.m_CifgEnabled = false
+ params.m_InputToInputWeights = &inputToInputWeights;
+ params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ params.m_CellToInputWeights = &cellToInputWeights;
+ params.m_InputGateBias = &inputGateBias;
+
+ // additional params because: descriptor.m_ProjectionEnabled = true
+ params.m_ProjectionWeights = &projectionWeights;
+ params.m_ProjectionBias = &projectionBias;
+
+ // additional params because: descriptor.m_PeepholeEnabled = true
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+ const std::string layerName("lstm");
+ armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str());
+ armnn::IConnectableLayer* const scratchBuffer = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(2);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(3);
+
+ // connect up
+ armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 4 }, armnn::DataType::Float32);
+
+ inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff);
+
+ lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyLstmLayer checker(
+ layerName,
+ {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+ {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo},
+ descriptor,
+ params);
+ deserializedNetwork->Accept(checker);
+}
+
BOOST_AUTO_TEST_SUITE_END()