aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2020-05-13 10:27:58 +0100
committerJames Conroy <james.conroy@arm.com>2020-05-13 23:06:38 +0000
commit8d33318a7ac33d90ed79701ff717de8d9940cc67 (patch)
tree2cf4140ec37b5b0a43b9618bab7f4f8076b5f4ab
parent5061601fb6833dda20a6097af6a92e5e07310f25 (diff)
downloadarmnn-8d33318a7ac33d90ed79701ff717de8d9940cc67.tar.gz
IVGCVSW-4777 Add QLstm serialization support
* Adds serialization/deserialization for QLstm. * 3 unit tests: basic, layer norm and advanced. Signed-off-by: James Conroy <james.conroy@arm.com> Change-Id: I97d825e06b0d4a1257713cdd71ff06afa10d4380
-rw-r--r--src/armnnDeserializer/Deserializer.cpp152
-rw-r--r--src/armnnDeserializer/Deserializer.hpp3
-rw-r--r--src/armnnDeserializer/DeserializerSupport.md1
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs93
-rw-r--r--src/armnnSerializer/Serializer.cpp119
-rw-r--r--src/armnnSerializer/SerializerSupport.md1
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp657
7 files changed, 1008 insertions, 18 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 42b0052b03..36beebc1cd 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -222,6 +222,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
m_ParserFunctions[Layer_PermuteLayer] = &Deserializer::ParsePermute;
m_ParserFunctions[Layer_Pooling2dLayer] = &Deserializer::ParsePooling2d;
m_ParserFunctions[Layer_PreluLayer] = &Deserializer::ParsePrelu;
+ m_ParserFunctions[Layer_QLstmLayer] = &Deserializer::ParseQLstm;
m_ParserFunctions[Layer_QuantizeLayer] = &Deserializer::ParseQuantize;
m_ParserFunctions[Layer_QuantizedLstmLayer] = &Deserializer::ParseQuantizedLstm;
m_ParserFunctions[Layer_ReshapeLayer] = &Deserializer::ParseReshape;
@@ -322,6 +323,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
case Layer::Layer_PreluLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base();
+ case Layer::Layer_QLstmLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_QLstmLayer()->base();
case Layer::Layer_QuantizeLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base();
case Layer::Layer_QuantizedLstmLayer:
@@ -2610,6 +2613,155 @@ void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex)
RegisterOutputSlots(graph, layerIndex, layer);
}
+armnn::QLstmDescriptor Deserializer::GetQLstmDescriptor(Deserializer::QLstmDescriptorPtr qLstmDescriptor)
+{
+ armnn::QLstmDescriptor desc;
+
+ desc.m_CifgEnabled = qLstmDescriptor->cifgEnabled();
+ desc.m_PeepholeEnabled = qLstmDescriptor->peepholeEnabled();
+ desc.m_ProjectionEnabled = qLstmDescriptor->projectionEnabled();
+ desc.m_LayerNormEnabled = qLstmDescriptor->layerNormEnabled();
+
+ desc.m_CellClip = qLstmDescriptor->cellClip();
+ desc.m_ProjectionClip = qLstmDescriptor->projectionClip();
+
+ desc.m_InputIntermediateScale = qLstmDescriptor->inputIntermediateScale();
+ desc.m_ForgetIntermediateScale = qLstmDescriptor->forgetIntermediateScale();
+ desc.m_CellIntermediateScale = qLstmDescriptor->cellIntermediateScale();
+ desc.m_OutputIntermediateScale = qLstmDescriptor->outputIntermediateScale();
+
+ desc.m_HiddenStateScale = qLstmDescriptor->hiddenStateScale();
+ desc.m_HiddenStateZeroPoint = qLstmDescriptor->hiddenStateZeroPoint();
+
+ return desc;
+}
+
+void Deserializer::ParseQLstm(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+
+ auto inputs = GetInputs(graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 3);
+
+ auto outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 3);
+
+ auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_QLstmLayer();
+ auto layerName = GetLayerName(graph, layerIndex);
+ auto flatBufferDescriptor = flatBufferLayer->descriptor();
+ auto flatBufferInputParams = flatBufferLayer->inputParams();
+
+ auto qLstmDescriptor = GetQLstmDescriptor(flatBufferDescriptor);
+ armnn::LstmInputParams qLstmInputParams;
+
+ // Mandatory params
+ armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
+ armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
+ armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
+ armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
+ armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
+ armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
+ armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
+ armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
+ armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
+
+ qLstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
+ qLstmInputParams.m_InputToCellWeights = &inputToCellWeights;
+ qLstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
+ qLstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ qLstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ qLstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ qLstmInputParams.m_ForgetGateBias = &forgetGateBias;
+ qLstmInputParams.m_CellBias = &cellBias;
+ qLstmInputParams.m_OutputGateBias = &outputGateBias;
+
+ // Optional CIFG params
+ armnn::ConstTensor inputToInputWeights;
+ armnn::ConstTensor recurrentToInputWeights;
+ armnn::ConstTensor inputGateBias;
+
+ if (!qLstmDescriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
+ recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
+ inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
+
+ qLstmInputParams.m_InputToInputWeights = &inputToInputWeights;
+ qLstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ qLstmInputParams.m_InputGateBias = &inputGateBias;
+ }
+
+ // Optional projection params
+ armnn::ConstTensor projectionWeights;
+ armnn::ConstTensor projectionBias;
+
+ if (qLstmDescriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
+ projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());
+
+ qLstmInputParams.m_ProjectionWeights = &projectionWeights;
+ qLstmInputParams.m_ProjectionBias = &projectionBias;
+ }
+
+ // Optional peephole params
+ armnn::ConstTensor cellToInputWeights;
+ armnn::ConstTensor cellToForgetWeights;
+ armnn::ConstTensor cellToOutputWeights;
+
+ if (qLstmDescriptor.m_PeepholeEnabled)
+ {
+ if (!qLstmDescriptor.m_CifgEnabled)
+ {
+ cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
+ qLstmInputParams.m_CellToInputWeights = &cellToInputWeights;
+ }
+
+ cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
+ cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
+
+ qLstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
+ qLstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
+ }
+
+ // Optional layer norm params
+ armnn::ConstTensor inputLayerNormWeights;
+ armnn::ConstTensor forgetLayerNormWeights;
+ armnn::ConstTensor cellLayerNormWeights;
+ armnn::ConstTensor outputLayerNormWeights;
+
+ if (qLstmDescriptor.m_LayerNormEnabled)
+ {
+ if (!qLstmDescriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
+ qLstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
+ }
+
+ forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
+ cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
+ outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());
+
+ qLstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+ qLstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights;
+ qLstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
+ }
+
+ IConnectableLayer* layer = m_Network->AddQLstmLayer(qLstmDescriptor, qLstmInputParams, layerName.c_str());
+
+ armnn::TensorInfo outputStateOutInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputStateOutInfo);
+
+ armnn::TensorInfo cellStateOutInfo = ToTensorInfo(outputs[1]);
+ layer->GetOutputSlot(1).SetTensorInfo(cellStateOutInfo);
+
+ armnn::TensorInfo outputInfo = ToTensorInfo(outputs[2]);
+ layer->GetOutputSlot(2).SetTensorInfo(outputInfo);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
+
void Deserializer::ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex)
{
CHECK_LAYERS(graph, 0, layerIndex);
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index f7e47cc8c2..d6ceced7c6 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -24,6 +24,7 @@ public:
using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
+ using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *;
using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *;
using TensorRawPtrVector = std::vector<TensorRawPtr>;
using LayerRawPtr = const armnnSerializer::LayerBase *;
@@ -62,6 +63,7 @@ public:
static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor);
static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
LstmInputParamsPtr lstmInputParams);
+ static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr);
static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
const std::vector<uint32_t> & targetDimsIn);
@@ -113,6 +115,7 @@ private:
void ParsePermute(GraphPtr graph, unsigned int layerIndex);
void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
void ParsePrelu(GraphPtr graph, unsigned int layerIndex);
+ void ParseQLstm(GraphPtr graph, unsigned int layerIndex);
void ParseQuantize(GraphPtr graph, unsigned int layerIndex);
void ParseReshape(GraphPtr graph, unsigned int layerIndex);
void ParseResize(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index 491cd1de42..c47810c11e 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -42,6 +42,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Pooling2d
* Prelu
* Quantize
+* QLstm
* QuantizedLstm
* Reshape
* Resize
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index ff79f6cffe..6e5ee3f3d3 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -155,7 +155,8 @@ enum LayerType : uint {
Comparison = 52,
StandIn = 53,
ElementwiseUnary = 54,
- Transpose = 55
+ Transpose = 55,
+ QLstm = 56
}
// Base layer table to be used as part of other layers
@@ -666,37 +667,96 @@ table LstmInputParams {
outputLayerNormWeights:ConstTensor;
}
-table QuantizedLstmInputParams {
- inputToInputWeights:ConstTensor;
+table LstmDescriptor {
+ activationFunc:uint;
+ clippingThresCell:float;
+ clippingThresProj:float;
+ cifgEnabled:bool = true;
+ peepholeEnabled:bool = false;
+ projectionEnabled:bool = false;
+ layerNormEnabled:bool = false;
+}
+
+table LstmLayer {
+ base:LayerBase;
+ descriptor:LstmDescriptor;
+ inputParams:LstmInputParams;
+}
+
+table QLstmInputParams {
+ // Mandatory
inputToForgetWeights:ConstTensor;
inputToCellWeights:ConstTensor;
inputToOutputWeights:ConstTensor;
- recurrentToInputWeights:ConstTensor;
recurrentToForgetWeights:ConstTensor;
recurrentToCellWeights:ConstTensor;
recurrentToOutputWeights:ConstTensor;
- inputGateBias:ConstTensor;
forgetGateBias:ConstTensor;
cellBias:ConstTensor;
outputGateBias:ConstTensor;
+
+ // CIFG
+ inputToInputWeights:ConstTensor;
+ recurrentToInputWeights:ConstTensor;
+ inputGateBias:ConstTensor;
+
+ // Projection
+ projectionWeights:ConstTensor;
+ projectionBias:ConstTensor;
+
+ // Peephole
+ cellToInputWeights:ConstTensor;
+ cellToForgetWeights:ConstTensor;
+ cellToOutputWeights:ConstTensor;
+
+ // Layer norm
+ inputLayerNormWeights:ConstTensor;
+ forgetLayerNormWeights:ConstTensor;
+ cellLayerNormWeights:ConstTensor;
+ outputLayerNormWeights:ConstTensor;
}
-table LstmDescriptor {
- activationFunc:uint;
- clippingThresCell:float;
- clippingThresProj:float;
- cifgEnabled:bool = true;
- peepholeEnabled:bool = false;
+table QLstmDescriptor {
+ cifgEnabled:bool = true;
+ peepholeEnabled:bool = false;
projectionEnabled:bool = false;
- layerNormEnabled:bool = false;
+ layerNormEnabled:bool = false;
+
+ cellClip:float;
+ projectionClip:float;
+
+ inputIntermediateScale:float;
+ forgetIntermediateScale:float;
+ cellIntermediateScale:float;
+ outputIntermediateScale:float;
+
+ hiddenStateZeroPoint:int;
+ hiddenStateScale:float;
}
-table LstmLayer {
+table QLstmLayer {
base:LayerBase;
- descriptor:LstmDescriptor;
- inputParams:LstmInputParams;
+ descriptor:QLstmDescriptor;
+ inputParams:QLstmInputParams;
+}
+
+table QuantizedLstmInputParams {
+ inputToInputWeights:ConstTensor;
+ inputToForgetWeights:ConstTensor;
+ inputToCellWeights:ConstTensor;
+ inputToOutputWeights:ConstTensor;
+
+ recurrentToInputWeights:ConstTensor;
+ recurrentToForgetWeights:ConstTensor;
+ recurrentToCellWeights:ConstTensor;
+ recurrentToOutputWeights:ConstTensor;
+
+ inputGateBias:ConstTensor;
+ forgetGateBias:ConstTensor;
+ cellBias:ConstTensor;
+ outputGateBias:ConstTensor;
}
table QuantizedLstmLayer {
@@ -836,7 +896,8 @@ union Layer {
ComparisonLayer,
StandInLayer,
ElementwiseUnaryLayer,
- TransposeLayer
+ TransposeLayer,
+ QLstmLayer
}
table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 355673697f..c4d3cfb5dd 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1335,9 +1335,124 @@ void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer,
const armnn::LstmInputParams& params,
const char* name)
{
- IgnoreUnused(layer, descriptor, params, name);
+ IgnoreUnused(name);
+
+ auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm);
+
+ auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor(
+ m_flatBufferBuilder,
+ descriptor.m_CifgEnabled,
+ descriptor.m_PeepholeEnabled,
+ descriptor.m_ProjectionEnabled,
+ descriptor.m_LayerNormEnabled,
+ descriptor.m_CellClip,
+ descriptor.m_ProjectionClip,
+ descriptor.m_InputIntermediateScale,
+ descriptor.m_ForgetIntermediateScale,
+ descriptor.m_CellIntermediateScale,
+ descriptor.m_OutputIntermediateScale,
+ descriptor.m_HiddenStateZeroPoint,
+ descriptor.m_HiddenStateScale
+ );
+
+ // Mandatory params
+ auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
+ auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
+ auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
+ auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
+ auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
+ auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
+ auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
+ auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
+ auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
+
+ // CIFG
+ flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
+ recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
+ inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
+ }
+
+    // Projection
+ flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+ flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+
+ if (descriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
+ projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
+ }
+
+ // Peephole
+ flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+
+ if (descriptor.m_PeepholeEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
+ }
+
+ cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
+ cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
+ }
+
+ // Layer norm
+ flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
+
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
+ }
+
+ forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
+ cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
+ outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
+ }
+
+ auto fbQLstmParams = serializer::CreateQLstmInputParams(
+ m_flatBufferBuilder,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ forgetGateBias,
+ cellBias,
+ outputGateBias,
+ inputToInputWeights,
+ recurrentToInputWeights,
+ inputGateBias,
+ projectionWeights,
+ projectionBias,
+ cellToInputWeights,
+ cellToForgetWeights,
+ cellToOutputWeights,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
+
+ auto fbQLstmLayer = serializer::CreateQLstmLayer(
+ m_flatBufferBuilder,
+ fbQLstmBaseLayer,
+ fbQLstmDescriptor,
+ fbQLstmParams);
- throw UnimplementedException("SerializerVisitor::VisitQLstmLayer not yet implemented");
+ CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer);
}
void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index 79b551f03a..8ba164c91e 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -40,6 +40,7 @@ The Arm NN SDK Serializer currently supports the following layers:
* Permute
* Pooling2d
* Prelu
+* QLstm
* Quantize
* QuantizedLstm
* Reshape
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index db89430439..76ac5a4de2 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -4326,4 +4326,661 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm)
deserializedNetwork->Accept(checker);
}
+class VerifyQLstmLayer : public LayerVerifierBaseWithDescriptor<armnn::QLstmDescriptor>
+{
+public:
+ VerifyQLstmLayer(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const armnn::QLstmDescriptor& descriptor,
+ const armnn::LstmInputParams& inputParams)
+ : LayerVerifierBaseWithDescriptor<armnn::QLstmDescriptor>(layerName, inputInfos, outputInfos, descriptor)
+ , m_InputParams(inputParams) {}
+
+ void VisitQLstmLayer(const armnn::IConnectableLayer* layer,
+ const armnn::QLstmDescriptor& descriptor,
+ const armnn::LstmInputParams& params,
+ const char* name)
+ {
+ VerifyNameAndConnections(layer, name);
+ VerifyDescriptor(descriptor);
+ VerifyInputParameters(params);
+ }
+
+protected:
+ void VerifyInputParameters(const armnn::LstmInputParams& params)
+ {
+ VerifyConstTensors(
+ "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights);
+ VerifyConstTensors(
+ "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights);
+ VerifyConstTensors(
+ "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights);
+ VerifyConstTensors(
+ "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights);
+ VerifyConstTensors(
+ "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights);
+ VerifyConstTensors(
+ "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights);
+ VerifyConstTensors(
+ "m_RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights);
+ VerifyConstTensors(
+ "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights);
+ VerifyConstTensors(
+ "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights);
+ VerifyConstTensors(
+ "m_CellToForgetWeights", m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights);
+ VerifyConstTensors(
+ "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights);
+ VerifyConstTensors(
+ "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias);
+ VerifyConstTensors(
+ "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias);
+ VerifyConstTensors(
+ "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias);
+ VerifyConstTensors(
+ "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias);
+ VerifyConstTensors(
+ "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights);
+ VerifyConstTensors(
+ "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias);
+ VerifyConstTensors(
+ "m_InputLayerNormWeights", m_InputParams.m_InputLayerNormWeights, params.m_InputLayerNormWeights);
+ VerifyConstTensors(
+ "m_ForgetLayerNormWeights", m_InputParams.m_ForgetLayerNormWeights, params.m_ForgetLayerNormWeights);
+ VerifyConstTensors(
+ "m_CellLayerNormWeights", m_InputParams.m_CellLayerNormWeights, params.m_CellLayerNormWeights);
+ VerifyConstTensors(
+ "m_OutputLayerNormWeights", m_InputParams.m_OutputLayerNormWeights, params.m_OutputLayerNormWeights);
+ }
+
+private:
+ armnn::LstmInputParams m_InputParams;
+};
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic)
+{
+ armnn::QLstmDescriptor descriptor;
+
+ descriptor.m_CifgEnabled = true;
+ descriptor.m_ProjectionEnabled = false;
+ descriptor.m_PeepholeEnabled = false;
+ descriptor.m_LayerNormEnabled = false;
+
+ descriptor.m_CellClip = 0.0f;
+ descriptor.m_ProjectionClip = 0.0f;
+
+ descriptor.m_InputIntermediateScale = 0.00001f;
+ descriptor.m_ForgetIntermediateScale = 0.00001f;
+ descriptor.m_CellIntermediateScale = 0.00001f;
+ descriptor.m_OutputIntermediateScale = 0.00001f;
+
+ descriptor.m_HiddenStateScale = 0.07f;
+ descriptor.m_HiddenStateZeroPoint = 0;
+
+ const unsigned int numBatches = 2;
+ const unsigned int inputSize = 5;
+ const unsigned int outputSize = 4;
+ const unsigned int numUnits = 4;
+
+ // Scale/Offset quantization info
+ float inputScale = 0.0078f;
+ int32_t inputOffset = 0;
+
+ float outputScale = 0.0078f;
+ int32_t outputOffset = 0;
+
+ float cellStateScale = 3.5002e-05f;
+ int32_t cellStateOffset = 0;
+
+ float weightsScale = 0.007f;
+ int32_t weightsOffset = 0;
+
+ float biasScale = 3.5002e-05f / 1024;
+ int32_t biasOffset = 0;
+
+ // Weights and bias tensor and quantization info
+ armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);
+
+ std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToCellWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData);
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData);
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData);
+
+ std::vector<int8_t> recurrentToForgetWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToCellWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToOutputWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData);
+ armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData);
+ armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData);
+
+ std::vector<int32_t> forgetGateBiasData(numUnits, 1);
+ std::vector<int32_t> cellBiasData(numUnits, 0);
+ std::vector<int32_t> outputGateBiasData(numUnits, 0);
+
+ armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData);
+ armnn::ConstTensor cellBias(biasInfo, cellBiasData);
+ armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData);
+
+ // Set up params
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // Create network
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ const std::string layerName("qLstm");
+
+ armnn::IConnectableLayer* const input = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2);
+
+ armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str());
+
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2);
+
+ // Input/Output tensor info
+ armnn::TensorInfo inputInfo({numBatches , inputSize},
+ armnn::DataType::QAsymmS8,
+ inputScale,
+ inputOffset);
+
+ armnn::TensorInfo cellStateInfo({numBatches , numUnits},
+ armnn::DataType::QSymmS16,
+ cellStateScale,
+ cellStateOffset);
+
+ armnn::TensorInfo outputStateInfo({numBatches , outputSize},
+ armnn::DataType::QAsymmS8,
+ outputScale,
+ outputOffset);
+
+ // Connect input/output slots
+ input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0));
+ input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo);
+
+ qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyQLstmLayer checker(layerName,
+ {inputInfo, cellStateInfo, outputStateInfo},
+ {outputStateInfo, cellStateInfo, outputStateInfo},
+ descriptor,
+ params);
+
+ deserializedNetwork->Accept(checker);
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm)
+{
+ armnn::QLstmDescriptor descriptor;
+
+    // CIFG params are only used when CIFG is disabled; CIFG is enabled here, so none are set
+ descriptor.m_CifgEnabled = true;
+ descriptor.m_ProjectionEnabled = false;
+ descriptor.m_PeepholeEnabled = false;
+ descriptor.m_LayerNormEnabled = true;
+
+ descriptor.m_CellClip = 0.0f;
+ descriptor.m_ProjectionClip = 0.0f;
+
+ descriptor.m_InputIntermediateScale = 0.00001f;
+ descriptor.m_ForgetIntermediateScale = 0.00001f;
+ descriptor.m_CellIntermediateScale = 0.00001f;
+ descriptor.m_OutputIntermediateScale = 0.00001f;
+
+ descriptor.m_HiddenStateScale = 0.07f;
+ descriptor.m_HiddenStateZeroPoint = 0;
+
+ const unsigned int numBatches = 2;
+ const unsigned int inputSize = 5;
+ const unsigned int outputSize = 4;
+ const unsigned int numUnits = 4;
+
+ // Scale/Offset quantization info
+ float inputScale = 0.0078f;
+ int32_t inputOffset = 0;
+
+ float outputScale = 0.0078f;
+ int32_t outputOffset = 0;
+
+ float cellStateScale = 3.5002e-05f;
+ int32_t cellStateOffset = 0;
+
+ float weightsScale = 0.007f;
+ int32_t weightsOffset = 0;
+
+ float layerNormScale = 3.5002e-05f;
+ int32_t layerNormOffset = 0;
+
+ float biasScale = layerNormScale / 1024;
+ int32_t biasOffset = 0;
+
+ // Weights and bias tensor and quantization info
+ armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo biasInfo({numUnits},
+ armnn::DataType::Signed32,
+ biasScale,
+ biasOffset);
+
+ armnn::TensorInfo layerNormWeightsInfo({numUnits},
+ armnn::DataType::QSymmS16,
+ layerNormScale,
+ layerNormOffset);
+
+ // Mandatory params
+ std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToCellWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData);
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData);
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData);
+
+ std::vector<int8_t> recurrentToForgetWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToCellWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToOutputWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData);
+ armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData);
+ armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData);
+
+ std::vector<int32_t> forgetGateBiasData(numUnits, 1);
+ std::vector<int32_t> cellBiasData(numUnits, 0);
+ std::vector<int32_t> outputGateBiasData(numUnits, 0);
+
+ armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData);
+ armnn::ConstTensor cellBias(biasInfo, cellBiasData);
+ armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData);
+
+ // Layer Norm
+ std::vector<int16_t> forgetLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+ std::vector<int16_t> cellLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+ std::vector<int16_t> outputLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsData);
+ armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsData);
+ armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsData);
+
+ // Set up params
+ armnn::LstmInputParams params;
+
+ // Mandatory params
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // Layer Norm
+ params.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+ params.m_CellLayerNormWeights = &cellLayerNormWeights;
+ params.m_OutputLayerNormWeights = &outputLayerNormWeights;
+
+ // Create network
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ const std::string layerName("qLstm");
+
+ armnn::IConnectableLayer* const input = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2);
+
+ armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str());
+
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2);
+
+ // Input/Output tensor info
+ armnn::TensorInfo inputInfo({numBatches , inputSize},
+ armnn::DataType::QAsymmS8,
+ inputScale,
+ inputOffset);
+
+ armnn::TensorInfo cellStateInfo({numBatches , numUnits},
+ armnn::DataType::QSymmS16,
+ cellStateScale,
+ cellStateOffset);
+
+ armnn::TensorInfo outputStateInfo({numBatches , outputSize},
+ armnn::DataType::QAsymmS8,
+ outputScale,
+ outputOffset);
+
+ // Connect input/output slots
+ input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0));
+ input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo);
+
+ qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyQLstmLayer checker(layerName,
+ {inputInfo, cellStateInfo, outputStateInfo},
+ {outputStateInfo, cellStateInfo, outputStateInfo},
+ descriptor,
+ params);
+
+ deserializedNetwork->Accept(checker);
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced)
+{
+    // QLstm round-trip test with every optional feature enabled: peephole,
+    // projection and layer normalisation. CIFG is disabled, so the input gate
+    // parameters (inputToInputWeights etc.) must also be supplied.
+    armnn::QLstmDescriptor descriptor;
+
+    descriptor.m_CifgEnabled       = false;
+    descriptor.m_ProjectionEnabled = true;
+    descriptor.m_PeepholeEnabled   = true;
+    descriptor.m_LayerNormEnabled  = true;
+
+    descriptor.m_CellClip       = 0.1f;
+    descriptor.m_ProjectionClip = 0.1f;
+
+    descriptor.m_InputIntermediateScale  = 0.00001f;
+    descriptor.m_ForgetIntermediateScale = 0.00001f;
+    descriptor.m_CellIntermediateScale   = 0.00001f;
+    descriptor.m_OutputIntermediateScale = 0.00001f;
+
+    descriptor.m_HiddenStateScale     = 0.07f;
+    descriptor.m_HiddenStateZeroPoint = 0;
+
+    const unsigned int numBatches = 2;
+    const unsigned int inputSize  = 5;
+    const unsigned int outputSize = 4;
+    const unsigned int numUnits   = 4;
+
+    // Scale/Offset quantization info
+    float inputScale    = 0.0078f;
+    int32_t inputOffset = 0;
+
+    float outputScale    = 0.0078f;
+    int32_t outputOffset = 0;
+
+    float cellStateScale    = 3.5002e-05f;
+    int32_t cellStateOffset = 0;
+
+    float weightsScale    = 0.007f;
+    int32_t weightsOffset = 0;
+
+    float layerNormScale    = 3.5002e-05f;
+    int32_t layerNormOffset = 0;
+
+    // Bias scale follows the TfLite quantized LSTM convention: layerNormScale / 1024
+    float biasScale    = layerNormScale / 1024;
+    int32_t biasOffset = 0;
+
+    // Weights and bias tensor and quantization info
+    armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
+                                       armnn::DataType::QSymmS8,
+                                       weightsScale,
+                                       weightsOffset);
+
+    armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
+                                           armnn::DataType::QSymmS8,
+                                           weightsScale,
+                                           weightsOffset);
+
+    armnn::TensorInfo biasInfo({numUnits},
+                               armnn::DataType::Signed32,
+                               biasScale,
+                               biasOffset);
+
+    armnn::TensorInfo peepholeWeightsInfo({numUnits},
+                                          armnn::DataType::QSymmS16,
+                                          weightsScale,
+                                          weightsOffset);
+
+    armnn::TensorInfo layerNormWeightsInfo({numUnits},
+                                           armnn::DataType::QSymmS16,
+                                           layerNormScale,
+                                           layerNormOffset);
+
+    armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
+                                            armnn::DataType::QSymmS8,
+                                            weightsScale,
+                                            weightsOffset);
+
+    // Mandatory params
+    std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+    std::vector<int8_t> inputToCellWeightsData   = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+    std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+
+    armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData);
+    armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData);
+    armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData);
+
+    std::vector<int8_t> recurrentToForgetWeightsData =
+            GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+    std::vector<int8_t> recurrentToCellWeightsData =
+            GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+    std::vector<int8_t> recurrentToOutputWeightsData =
+            GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+
+    armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData);
+    armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData);
+    armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData);
+
+    std::vector<int32_t> forgetGateBiasData(numUnits, 1);
+    std::vector<int32_t> cellBiasData(numUnits, 0);
+    std::vector<int32_t> outputGateBiasData(numUnits, 0);
+
+    armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData);
+    armnn::ConstTensor cellBias(biasInfo, cellBiasData);
+    armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData);
+
+    // Input gate params (required because CIFG is disabled)
+    std::vector<int8_t> inputToInputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+    std::vector<int8_t> recurrentToInputWeightsData =
+            GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+    std::vector<int32_t> inputGateBiasData(numUnits, 1);
+
+    armnn::ConstTensor inputToInputWeights(inputWeightsInfo, inputToInputWeightsData);
+    armnn::ConstTensor recurrentToInputWeights(recurrentWeightsInfo, recurrentToInputWeightsData);
+    armnn::ConstTensor inputGateBias(biasInfo, inputGateBiasData);
+
+    // Peephole
+    std::vector<int16_t> cellToInputWeightsData  = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements());
+    std::vector<int16_t> cellToForgetWeightsData = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements());
+    std::vector<int16_t> cellToOutputWeightsData = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements());
+
+    armnn::ConstTensor cellToInputWeights(peepholeWeightsInfo, cellToInputWeightsData);
+    armnn::ConstTensor cellToForgetWeights(peepholeWeightsInfo, cellToForgetWeightsData);
+    armnn::ConstTensor cellToOutputWeights(peepholeWeightsInfo, cellToOutputWeightsData);
+
+    // Projection
+    std::vector<int8_t> projectionWeightsData = GenerateRandomData<int8_t>(projectionWeightsInfo.GetNumElements());
+    std::vector<int32_t> projectionBiasData(outputSize, 1);
+
+    armnn::ConstTensor projectionWeights(projectionWeightsInfo, projectionWeightsData);
+    armnn::ConstTensor projectionBias(biasInfo, projectionBiasData);
+
+    // Layer Norm
+    std::vector<int16_t> inputLayerNormWeightsData =
+            GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+    std::vector<int16_t> forgetLayerNormWeightsData =
+            GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+    std::vector<int16_t> cellLayerNormWeightsData =
+            GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+    std::vector<int16_t> outputLayerNormWeightsData =
+            GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+
+    armnn::ConstTensor inputLayerNormWeights(layerNormWeightsInfo, inputLayerNormWeightsData);
+    armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsData);
+    armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsData);
+    armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsData);
+
+    // Set up params
+    armnn::LstmInputParams params;
+
+    // Mandatory params
+    params.m_InputToForgetWeights = &inputToForgetWeights;
+    params.m_InputToCellWeights   = &inputToCellWeights;
+    params.m_InputToOutputWeights = &inputToOutputWeights;
+
+    params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+    params.m_RecurrentToCellWeights   = &recurrentToCellWeights;
+    params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+    params.m_ForgetGateBias = &forgetGateBias;
+    params.m_CellBias       = &cellBias;
+    params.m_OutputGateBias = &outputGateBias;
+
+    // Input gate params (required because CIFG is disabled)
+    params.m_InputToInputWeights     = &inputToInputWeights;
+    params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+    params.m_InputGateBias           = &inputGateBias;
+
+    // Peephole
+    params.m_CellToInputWeights  = &cellToInputWeights;
+    params.m_CellToForgetWeights = &cellToForgetWeights;
+    params.m_CellToOutputWeights = &cellToOutputWeights;
+
+    // Projection
+    params.m_ProjectionWeights = &projectionWeights;
+    params.m_ProjectionBias    = &projectionBias;
+
+    // Layer Norm
+    params.m_InputLayerNormWeights  = &inputLayerNormWeights;
+    params.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+    params.m_CellLayerNormWeights   = &cellLayerNormWeights;
+    params.m_OutputLayerNormWeights = &outputLayerNormWeights;
+
+    // Create network
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    const std::string layerName("qLstm");
+
+    // QLstm input order: 0 = input, 1 = outputStateIn, 2 = cellStateIn
+    armnn::IConnectableLayer* const input         = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1);
+    armnn::IConnectableLayer* const cellStateIn   = network->AddInputLayer(2);
+
+    armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str());
+
+    // QLstm output order: 0 = outputStateOut, 1 = cellStateOut, 2 = output
+    armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0);
+    armnn::IConnectableLayer* const cellStateOut   = network->AddOutputLayer(1);
+    armnn::IConnectableLayer* const outputLayer    = network->AddOutputLayer(2);
+
+    // Input/Output tensor info
+    armnn::TensorInfo inputInfo({numBatches, inputSize},
+                                armnn::DataType::QAsymmS8,
+                                inputScale,
+                                inputOffset);
+
+    armnn::TensorInfo cellStateInfo({numBatches, numUnits},
+                                    armnn::DataType::QSymmS16,
+                                    cellStateScale,
+                                    cellStateOffset);
+
+    armnn::TensorInfo outputStateInfo({numBatches, outputSize},
+                                      armnn::DataType::QAsymmS8,
+                                      outputScale,
+                                      outputOffset);
+
+    // Connect input/output slots
+    input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0));
+    input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+
+    // Fix: the output state input carries outputStateInfo (QAsymmS8) and the
+    // cell state input carries cellStateInfo (QSymmS16) - previously swapped.
+    outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1));
+    outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+    cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2));
+    cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo);
+
+    qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0));
+    qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+    qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0));
+    qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo);
+
+    qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+    qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
+
+    // Serialize then deserialize and verify the QLstm layer survives the round trip
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    // Expected input infos follow slot order: input, outputStateIn, cellStateIn
+    VerifyQLstmLayer checker(layerName,
+                             {inputInfo, outputStateInfo, cellStateInfo},
+                             {outputStateInfo, cellStateInfo, outputStateInfo},
+                             descriptor,
+                             params);
+
+    deserializedNetwork->Accept(checker);
+}
+
BOOST_AUTO_TEST_SUITE_END()