author     Narumol Prangnawarat <narumol.prangnawarat@arm.com>   2021-07-23 14:47:49 +0100
committer  Narumol Prangnawarat <narumol.prangnawarat@arm.com>   2021-07-28 12:03:02 +0100
commit     a0162e17c56538ee6d72ecce4c3e0836cbb34c56 (patch)
tree       c47230c4024d7e79cacb39dafe179cdcf4571ade
parent     996f0f59e5b8a9ac73503814f7aadff4ef74cd35 (diff)
download   armnn-a0162e17c56538ee6d72ecce4c3e0836cbb34c56.tar.gz
MLCE-530 Add Serializer and Deserializer for UnidirectionalSequenceLstm
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ic1c56a57941ebede19ab8b9032e7f9df1221be7a
-rw-r--r--  docs/01_02_deserializer_serializer.dox                  2
-rw-r--r--  src/armnnDeserializer/Deserializer.cpp                 133
-rw-r--r--  src/armnnDeserializer/Deserializer.hpp                   4
-rw-r--r--  src/armnnSerializer/ArmnnSchema.fbs                     23
-rw-r--r--  src/armnnSerializer/ArmnnSchema_generated.h             224
-rw-r--r--  src/armnnSerializer/Serializer.cpp                      124
-rw-r--r--  src/armnnSerializer/Serializer.hpp                        5
-rw-r--r--  src/armnnSerializer/test/LstmSerializationTests.cpp     506
8 files changed, 1008 insertions, 13 deletions
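
With this change a UnidirectionalSequenceLstm layer survives a serialize/deserialize round trip. A minimal sketch of how a caller would exercise that, assuming the existing public armnnSerializer::ISerializer and armnnDeserializer::IDeserializer interfaces (names and include paths taken from the Arm NN headers, not part of this patch):

    #include <armnn/INetwork.hpp>
    #include <armnnSerializer/ISerializer.hpp>
    #include <armnnDeserializer/IDeserializer.hpp>
    #include <cstdint>
    #include <sstream>
    #include <string>
    #include <vector>

    // Serialize a network containing a UnidirectionalSequenceLstm layer and read it back.
    armnn::INetworkPtr RoundTrip(const armnn::INetwork& network)
    {
        auto serializer = armnnSerializer::ISerializer::Create();
        serializer->Serialize(network);

        std::stringstream stream;
        serializer->SaveSerializedToStream(stream);

        const std::string data = stream.str();
        std::vector<uint8_t> binary(data.begin(), data.end());

        auto deserializer = armnnDeserializer::IDeserializer::Create();
        return deserializer->CreateNetworkFromBinary(binary);
    }

The test cases added in LstmSerializationTests.cpp below follow the same pattern through the SerializeNetwork/DeserializeNetwork test helpers and then verify the reconstructed layer with VerifyLstmLayer.
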
diff --git a/docs/01_02_deserializer_serializer.dox b/docs/01_02_deserializer_serializer.dox
index 811af1cfb5..81b4ad0e9b 100644
--- a/docs/01_02_deserializer_serializer.dox
+++ b/docs/01_02_deserializer_serializer.dox
@@ -76,6 +76,7 @@ The Arm NN SDK Serializer currently supports the following layers:
- Switch
- Transpose
- TransposeConvolution2d
+- UnidirectionalSequenceLstm
More machine learning layers will be supported in future releases.
@@ -163,6 +164,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
- Switch
- Transpose
- TransposeConvolution2d
+- UnidirectionalSequenceLstm
More machine learning layers will be supported in future releases.
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index af6ff842a7..2d9194a350 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -270,6 +270,7 @@ m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupporte
m_ParserFunctions[Layer_SwitchLayer] = &DeserializerImpl::ParseSwitch;
m_ParserFunctions[Layer_TransposeConvolution2dLayer] = &DeserializerImpl::ParseTransposeConvolution2d;
m_ParserFunctions[Layer_TransposeLayer] = &DeserializerImpl::ParseTranspose;
+ m_ParserFunctions[Layer_UnidirectionalSequenceLstmLayer] = &DeserializerImpl::ParseUnidirectionalSequenceLstm;
}
LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex)
@@ -404,6 +405,8 @@ LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& gr
return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeConvolution2dLayer()->base();
case Layer::Layer_TransposeLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeLayer()->base();
+ case Layer::Layer_UnidirectionalSequenceLstmLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer()->base();
case Layer::Layer_NONE:
default:
throw ParseException(fmt::format("Layer type {} not recognized", layerType));
@@ -3325,4 +3328,134 @@ void IDeserializer::DeserializerImpl::ParseStandIn(GraphPtr graph, unsigned int
RegisterOutputSlots(graph, layerIndex, layer);
}
+armnn::UnidirectionalSequenceLstmDescriptor IDeserializer::DeserializerImpl::GetUnidirectionalSequenceLstmDescriptor(
+ UnidirectionalSequenceLstmDescriptorPtr descriptor)
+{
+ armnn::UnidirectionalSequenceLstmDescriptor desc;
+
+ desc.m_ActivationFunc = descriptor->activationFunc();
+ desc.m_ClippingThresCell = descriptor->clippingThresCell();
+ desc.m_ClippingThresProj = descriptor->clippingThresProj();
+ desc.m_CifgEnabled = descriptor->cifgEnabled();
+ desc.m_PeepholeEnabled = descriptor->peepholeEnabled();
+ desc.m_ProjectionEnabled = descriptor->projectionEnabled();
+ desc.m_LayerNormEnabled = descriptor->layerNormEnabled();
+ desc.m_TimeMajor = descriptor->timeMajor();
+
+ return desc;
+}
+
+void IDeserializer::DeserializerImpl::ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+
+ auto inputs = GetInputs(graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 3);
+
+ auto outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer();
+ auto layerName = GetLayerName(graph, layerIndex);
+ auto flatBufferDescriptor = flatBufferLayer->descriptor();
+ auto flatBufferInputParams = flatBufferLayer->inputParams();
+
+ auto descriptor = GetUnidirectionalSequenceLstmDescriptor(flatBufferDescriptor);
+
+ armnn::LstmInputParams lstmInputParams;
+
+ armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
+ armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
+ armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
+ armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
+ armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
+ armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
+ armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
+ armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
+ armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
+
+ lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
+ lstmInputParams.m_InputToCellWeights = &inputToCellWeights;
+ lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
+ lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ lstmInputParams.m_ForgetGateBias = &forgetGateBias;
+ lstmInputParams.m_CellBias = &cellBias;
+ lstmInputParams.m_OutputGateBias = &outputGateBias;
+
+ armnn::ConstTensor inputToInputWeights;
+ armnn::ConstTensor recurrentToInputWeights;
+ armnn::ConstTensor cellToInputWeights;
+ armnn::ConstTensor inputGateBias;
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
+ recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
+ inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
+
+ lstmInputParams.m_InputToInputWeights = &inputToInputWeights;
+ lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ lstmInputParams.m_InputGateBias = &inputGateBias;
+
+ if (descriptor.m_PeepholeEnabled)
+ {
+ cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
+ lstmInputParams.m_CellToInputWeights = &cellToInputWeights;
+ }
+ }
+
+ armnn::ConstTensor projectionWeights;
+ armnn::ConstTensor projectionBias;
+ if (descriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
+ projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());
+
+ lstmInputParams.m_ProjectionWeights = &projectionWeights;
+ lstmInputParams.m_ProjectionBias = &projectionBias;
+ }
+
+ armnn::ConstTensor cellToForgetWeights;
+ armnn::ConstTensor cellToOutputWeights;
+ if (descriptor.m_PeepholeEnabled)
+ {
+ cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
+ cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
+
+ lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
+ lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
+ }
+
+ armnn::ConstTensor inputLayerNormWeights;
+ armnn::ConstTensor forgetLayerNormWeights;
+ armnn::ConstTensor cellLayerNormWeights;
+ armnn::ConstTensor outputLayerNormWeights;
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
+ lstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
+ }
+ forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
+ cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
+ outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());
+
+ lstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+ lstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights;
+ lstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
+ }
+
+ IConnectableLayer* layer = m_Network->AddUnidirectionalSequenceLstmLayer(descriptor,
+ lstmInputParams,
+ layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
+
} // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 0b05e16849..b1362c44b6 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -28,6 +28,7 @@ using TensorRawPtrVector = std::vector<TensorRawPtr>;
using LayerRawPtr = const armnnSerializer::LayerBase *;
using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
+using UnidirectionalSequenceLstmDescriptorPtr = const armnnSerializer::UnidirectionalSequenceLstmDescriptor *;
class IDeserializer::DeserializerImpl
{
@@ -67,6 +68,8 @@ public:
static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
LstmInputParamsPtr lstmInputParams);
static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr);
+ static armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor(
+ UnidirectionalSequenceLstmDescriptorPtr descriptor);
static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
const std::vector<uint32_t> & targetDimsIn);
@@ -138,6 +141,7 @@ private:
void ParseSwitch(GraphPtr graph, unsigned int layerIndex);
void ParseTranspose(GraphPtr graph, unsigned int layerIndex);
void ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex);
+ void ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex);
void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex,
armnn::IConnectableLayer* layer);
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 32a9bba5ab..a544161c53 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -172,7 +172,8 @@ enum LayerType : uint {
LogicalBinary = 59,
Reduce = 60,
Cast = 61,
- Shape = 62
+ Shape = 62,
+ UnidirectionalSequenceLstm = 63,
}
// Base layer table to be used as part of other layers
@@ -915,6 +916,23 @@ table ReduceDescriptor {
reduceOperation:ReduceOperation = Sum;
}
+table UnidirectionalSequenceLstmDescriptor {
+ activationFunc:uint;
+ clippingThresCell:float;
+ clippingThresProj:float;
+ cifgEnabled:bool = true;
+ peepholeEnabled:bool = false;
+ projectionEnabled:bool = false;
+ layerNormEnabled:bool = false;
+ timeMajor:bool = false;
+}
+
+table UnidirectionalSequenceLstmLayer {
+ base:LayerBase;
+ descriptor:UnidirectionalSequenceLstmDescriptor;
+ inputParams:LstmInputParams;
+}
+
union Layer {
ActivationLayer,
AdditionLayer,
@@ -978,7 +996,8 @@ union Layer {
LogicalBinaryLayer,
ReduceLayer,
CastLayer,
- ShapeLayer
+ ShapeLayer,
+ UnidirectionalSequenceLstmLayer,
}
table AnyLayer {
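
The new descriptor table leans on FlatBuffers scalar defaults: a field equal to its declared default (for example cifgEnabled, which defaults to true) is not written into the buffer, and the generated accessor in ArmnnSchema_generated.h below supplies the default on read. A small hedged sketch of that behaviour, assuming the generated header is included as shown:

    #include "ArmnnSchema_generated.h"  // assumed include path for the generated header

    // Returns true for a descriptor that was serialized without an explicit cifgEnabled
    // field, because the schema above declares cifgEnabled:bool = true as its default.
    bool CifgFlagOf(const armnnSerializer::UnidirectionalSequenceLstmDescriptor& descriptor)
    {
        return descriptor.cifgEnabled();
    }
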
diff --git a/src/armnnSerializer/ArmnnSchema_generated.h b/src/armnnSerializer/ArmnnSchema_generated.h
index 4a352ddb6c..27550f0682 100644
--- a/src/armnnSerializer/ArmnnSchema_generated.h
+++ b/src/armnnSerializer/ArmnnSchema_generated.h
@@ -362,6 +362,12 @@ struct ReduceLayerBuilder;
struct ReduceDescriptor;
struct ReduceDescriptorBuilder;
+struct UnidirectionalSequenceLstmDescriptor;
+struct UnidirectionalSequenceLstmDescriptorBuilder;
+
+struct UnidirectionalSequenceLstmLayer;
+struct UnidirectionalSequenceLstmLayerBuilder;
+
struct AnyLayer;
struct AnyLayerBuilder;
@@ -740,11 +746,12 @@ enum LayerType {
LayerType_Reduce = 60,
LayerType_Cast = 61,
LayerType_Shape = 62,
+ LayerType_UnidirectionalSequenceLstm = 63,
LayerType_MIN = LayerType_Addition,
- LayerType_MAX = LayerType_Shape
+ LayerType_MAX = LayerType_UnidirectionalSequenceLstm
};
-inline const LayerType (&EnumValuesLayerType())[63] {
+inline const LayerType (&EnumValuesLayerType())[64] {
static const LayerType values[] = {
LayerType_Addition,
LayerType_Input,
@@ -808,13 +815,14 @@ inline const LayerType (&EnumValuesLayerType())[63] {
LayerType_LogicalBinary,
LayerType_Reduce,
LayerType_Cast,
- LayerType_Shape
+ LayerType_Shape,
+ LayerType_UnidirectionalSequenceLstm
};
return values;
}
inline const char * const *EnumNamesLayerType() {
- static const char * const names[64] = {
+ static const char * const names[65] = {
"Addition",
"Input",
"Multiplication",
@@ -878,13 +886,14 @@ inline const char * const *EnumNamesLayerType() {
"Reduce",
"Cast",
"Shape",
+ "UnidirectionalSequenceLstm",
nullptr
};
return names;
}
inline const char *EnumNameLayerType(LayerType e) {
- if (flatbuffers::IsOutRange(e, LayerType_Addition, LayerType_Shape)) return "";
+ if (flatbuffers::IsOutRange(e, LayerType_Addition, LayerType_UnidirectionalSequenceLstm)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesLayerType()[index];
}
@@ -1227,11 +1236,12 @@ enum Layer {
Layer_ReduceLayer = 61,
Layer_CastLayer = 62,
Layer_ShapeLayer = 63,
+ Layer_UnidirectionalSequenceLstmLayer = 64,
Layer_MIN = Layer_NONE,
- Layer_MAX = Layer_ShapeLayer
+ Layer_MAX = Layer_UnidirectionalSequenceLstmLayer
};
-inline const Layer (&EnumValuesLayer())[64] {
+inline const Layer (&EnumValuesLayer())[65] {
static const Layer values[] = {
Layer_NONE,
Layer_ActivationLayer,
@@ -1296,13 +1306,14 @@ inline const Layer (&EnumValuesLayer())[64] {
Layer_LogicalBinaryLayer,
Layer_ReduceLayer,
Layer_CastLayer,
- Layer_ShapeLayer
+ Layer_ShapeLayer,
+ Layer_UnidirectionalSequenceLstmLayer
};
return values;
}
inline const char * const *EnumNamesLayer() {
- static const char * const names[65] = {
+ static const char * const names[66] = {
"NONE",
"ActivationLayer",
"AdditionLayer",
@@ -1367,13 +1378,14 @@ inline const char * const *EnumNamesLayer() {
"ReduceLayer",
"CastLayer",
"ShapeLayer",
+ "UnidirectionalSequenceLstmLayer",
nullptr
};
return names;
}
inline const char *EnumNameLayer(Layer e) {
- if (flatbuffers::IsOutRange(e, Layer_NONE, Layer_ShapeLayer)) return "";
+ if (flatbuffers::IsOutRange(e, Layer_NONE, Layer_UnidirectionalSequenceLstmLayer)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesLayer()[index];
}
@@ -1634,6 +1646,10 @@ template<> struct LayerTraits<armnnSerializer::ShapeLayer> {
static const Layer enum_value = Layer_ShapeLayer;
};
+template<> struct LayerTraits<armnnSerializer::UnidirectionalSequenceLstmLayer> {
+ static const Layer enum_value = Layer_UnidirectionalSequenceLstmLayer;
+};
+
bool VerifyLayer(flatbuffers::Verifier &verifier, const void *obj, Layer type);
bool VerifyLayerVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
@@ -9425,6 +9441,183 @@ inline flatbuffers::Offset<ReduceDescriptor> CreateReduceDescriptorDirect(
reduceOperation);
}
+struct UnidirectionalSequenceLstmDescriptor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UnidirectionalSequenceLstmDescriptorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ACTIVATIONFUNC = 4,
+ VT_CLIPPINGTHRESCELL = 6,
+ VT_CLIPPINGTHRESPROJ = 8,
+ VT_CIFGENABLED = 10,
+ VT_PEEPHOLEENABLED = 12,
+ VT_PROJECTIONENABLED = 14,
+ VT_LAYERNORMENABLED = 16,
+ VT_TIMEMAJOR = 18
+ };
+ uint32_t activationFunc() const {
+ return GetField<uint32_t>(VT_ACTIVATIONFUNC, 0);
+ }
+ float clippingThresCell() const {
+ return GetField<float>(VT_CLIPPINGTHRESCELL, 0.0f);
+ }
+ float clippingThresProj() const {
+ return GetField<float>(VT_CLIPPINGTHRESPROJ, 0.0f);
+ }
+ bool cifgEnabled() const {
+ return GetField<uint8_t>(VT_CIFGENABLED, 1) != 0;
+ }
+ bool peepholeEnabled() const {
+ return GetField<uint8_t>(VT_PEEPHOLEENABLED, 0) != 0;
+ }
+ bool projectionEnabled() const {
+ return GetField<uint8_t>(VT_PROJECTIONENABLED, 0) != 0;
+ }
+ bool layerNormEnabled() const {
+ return GetField<uint8_t>(VT_LAYERNORMENABLED, 0) != 0;
+ }
+ bool timeMajor() const {
+ return GetField<uint8_t>(VT_TIMEMAJOR, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint32_t>(verifier, VT_ACTIVATIONFUNC) &&
+ VerifyField<float>(verifier, VT_CLIPPINGTHRESCELL) &&
+ VerifyField<float>(verifier, VT_CLIPPINGTHRESPROJ) &&
+ VerifyField<uint8_t>(verifier, VT_CIFGENABLED) &&
+ VerifyField<uint8_t>(verifier, VT_PEEPHOLEENABLED) &&
+ VerifyField<uint8_t>(verifier, VT_PROJECTIONENABLED) &&
+ VerifyField<uint8_t>(verifier, VT_LAYERNORMENABLED) &&
+ VerifyField<uint8_t>(verifier, VT_TIMEMAJOR) &&
+ verifier.EndTable();
+ }
+};
+
+struct UnidirectionalSequenceLstmDescriptorBuilder {
+ typedef UnidirectionalSequenceLstmDescriptor Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_activationFunc(uint32_t activationFunc) {
+ fbb_.AddElement<uint32_t>(UnidirectionalSequenceLstmDescriptor::VT_ACTIVATIONFUNC, activationFunc, 0);
+ }
+ void add_clippingThresCell(float clippingThresCell) {
+ fbb_.AddElement<float>(UnidirectionalSequenceLstmDescriptor::VT_CLIPPINGTHRESCELL, clippingThresCell, 0.0f);
+ }
+ void add_clippingThresProj(float clippingThresProj) {
+ fbb_.AddElement<float>(UnidirectionalSequenceLstmDescriptor::VT_CLIPPINGTHRESPROJ, clippingThresProj, 0.0f);
+ }
+ void add_cifgEnabled(bool cifgEnabled) {
+ fbb_.AddElement<uint8_t>(UnidirectionalSequenceLstmDescriptor::VT_CIFGENABLED, static_cast<uint8_t>(cifgEnabled), 1);
+ }
+ void add_peepholeEnabled(bool peepholeEnabled) {
+ fbb_.AddElement<uint8_t>(UnidirectionalSequenceLstmDescriptor::VT_PEEPHOLEENABLED, static_cast<uint8_t>(peepholeEnabled), 0);
+ }
+ void add_projectionEnabled(bool projectionEnabled) {
+ fbb_.AddElement<uint8_t>(UnidirectionalSequenceLstmDescriptor::VT_PROJECTIONENABLED, static_cast<uint8_t>(projectionEnabled), 0);
+ }
+ void add_layerNormEnabled(bool layerNormEnabled) {
+ fbb_.AddElement<uint8_t>(UnidirectionalSequenceLstmDescriptor::VT_LAYERNORMENABLED, static_cast<uint8_t>(layerNormEnabled), 0);
+ }
+ void add_timeMajor(bool timeMajor) {
+ fbb_.AddElement<uint8_t>(UnidirectionalSequenceLstmDescriptor::VT_TIMEMAJOR, static_cast<uint8_t>(timeMajor), 0);
+ }
+ explicit UnidirectionalSequenceLstmDescriptorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnidirectionalSequenceLstmDescriptorBuilder &operator=(const UnidirectionalSequenceLstmDescriptorBuilder &);
+ flatbuffers::Offset<UnidirectionalSequenceLstmDescriptor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UnidirectionalSequenceLstmDescriptor>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UnidirectionalSequenceLstmDescriptor> CreateUnidirectionalSequenceLstmDescriptor(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint32_t activationFunc = 0,
+ float clippingThresCell = 0.0f,
+ float clippingThresProj = 0.0f,
+ bool cifgEnabled = true,
+ bool peepholeEnabled = false,
+ bool projectionEnabled = false,
+ bool layerNormEnabled = false,
+ bool timeMajor = false) {
+ UnidirectionalSequenceLstmDescriptorBuilder builder_(_fbb);
+ builder_.add_clippingThresProj(clippingThresProj);
+ builder_.add_clippingThresCell(clippingThresCell);
+ builder_.add_activationFunc(activationFunc);
+ builder_.add_timeMajor(timeMajor);
+ builder_.add_layerNormEnabled(layerNormEnabled);
+ builder_.add_projectionEnabled(projectionEnabled);
+ builder_.add_peepholeEnabled(peepholeEnabled);
+ builder_.add_cifgEnabled(cifgEnabled);
+ return builder_.Finish();
+}
+
+struct UnidirectionalSequenceLstmLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UnidirectionalSequenceLstmLayerBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_DESCRIPTOR = 6,
+ VT_INPUTPARAMS = 8
+ };
+ const armnnSerializer::LayerBase *base() const {
+ return GetPointer<const armnnSerializer::LayerBase *>(VT_BASE);
+ }
+ const armnnSerializer::UnidirectionalSequenceLstmDescriptor *descriptor() const {
+ return GetPointer<const armnnSerializer::UnidirectionalSequenceLstmDescriptor *>(VT_DESCRIPTOR);
+ }
+ const armnnSerializer::LstmInputParams *inputParams() const {
+ return GetPointer<const armnnSerializer::LstmInputParams *>(VT_INPUTPARAMS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffset(verifier, VT_DESCRIPTOR) &&
+ verifier.VerifyTable(descriptor()) &&
+ VerifyOffset(verifier, VT_INPUTPARAMS) &&
+ verifier.VerifyTable(inputParams()) &&
+ verifier.EndTable();
+ }
+};
+
+struct UnidirectionalSequenceLstmLayerBuilder {
+ typedef UnidirectionalSequenceLstmLayer Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<armnnSerializer::LayerBase> base) {
+ fbb_.AddOffset(UnidirectionalSequenceLstmLayer::VT_BASE, base);
+ }
+ void add_descriptor(flatbuffers::Offset<armnnSerializer::UnidirectionalSequenceLstmDescriptor> descriptor) {
+ fbb_.AddOffset(UnidirectionalSequenceLstmLayer::VT_DESCRIPTOR, descriptor);
+ }
+ void add_inputParams(flatbuffers::Offset<armnnSerializer::LstmInputParams> inputParams) {
+ fbb_.AddOffset(UnidirectionalSequenceLstmLayer::VT_INPUTPARAMS, inputParams);
+ }
+ explicit UnidirectionalSequenceLstmLayerBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnidirectionalSequenceLstmLayerBuilder &operator=(const UnidirectionalSequenceLstmLayerBuilder &);
+ flatbuffers::Offset<UnidirectionalSequenceLstmLayer> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UnidirectionalSequenceLstmLayer>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UnidirectionalSequenceLstmLayer> CreateUnidirectionalSequenceLstmLayer(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<armnnSerializer::LayerBase> base = 0,
+ flatbuffers::Offset<armnnSerializer::UnidirectionalSequenceLstmDescriptor> descriptor = 0,
+ flatbuffers::Offset<armnnSerializer::LstmInputParams> inputParams = 0) {
+ UnidirectionalSequenceLstmLayerBuilder builder_(_fbb);
+ builder_.add_inputParams(inputParams);
+ builder_.add_descriptor(descriptor);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
struct AnyLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef AnyLayerBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
@@ -9627,6 +9820,9 @@ struct AnyLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const armnnSerializer::ShapeLayer *layer_as_ShapeLayer() const {
return layer_type() == armnnSerializer::Layer_ShapeLayer ? static_cast<const armnnSerializer::ShapeLayer *>(layer()) : nullptr;
}
+ const armnnSerializer::UnidirectionalSequenceLstmLayer *layer_as_UnidirectionalSequenceLstmLayer() const {
+ return layer_type() == armnnSerializer::Layer_UnidirectionalSequenceLstmLayer ? static_cast<const armnnSerializer::UnidirectionalSequenceLstmLayer *>(layer()) : nullptr;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_LAYER_TYPE) &&
@@ -9888,6 +10084,10 @@ template<> inline const armnnSerializer::ShapeLayer *AnyLayer::layer_as<armnnSer
return layer_as_ShapeLayer();
}
+template<> inline const armnnSerializer::UnidirectionalSequenceLstmLayer *AnyLayer::layer_as<armnnSerializer::UnidirectionalSequenceLstmLayer>() const {
+ return layer_as_UnidirectionalSequenceLstmLayer();
+}
+
struct AnyLayerBuilder {
typedef AnyLayer Table;
flatbuffers::FlatBufferBuilder &fbb_;
@@ -10360,6 +10560,10 @@ inline bool VerifyLayer(flatbuffers::Verifier &verifier, const void *obj, Layer
auto ptr = reinterpret_cast<const armnnSerializer::ShapeLayer *>(obj);
return verifier.VerifyTable(ptr);
}
+ case Layer_UnidirectionalSequenceLstmLayer: {
+ auto ptr = reinterpret_cast<const armnnSerializer::UnidirectionalSequenceLstmLayer *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return true;
}
}
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index fd7f8dc7dc..44cd1800c4 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1648,6 +1648,123 @@ void SerializerStrategy::SerializeQuantizedLstmLayer(const armnn::IConnectableLa
CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
}
+void SerializerStrategy::SerializeUnidirectionalSequenceLstmLayer(
+ const armnn::IConnectableLayer* layer,
+ const armnn::UnidirectionalSequenceLstmDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name)
+{
+ IgnoreUnused(name);
+
+ auto fbUnidirectionalSequenceLstmBaseLayer =
+ CreateLayerBase(layer, serializer::LayerType::LayerType_UnidirectionalSequenceLstm);
+
+ auto fbUnidirectionalSequenceLstmDescriptor = serializer::CreateUnidirectionalSequenceLstmDescriptor(
+ m_flatBufferBuilder,
+ descriptor.m_ActivationFunc,
+ descriptor.m_ClippingThresCell,
+ descriptor.m_ClippingThresProj,
+ descriptor.m_CifgEnabled,
+ descriptor.m_PeepholeEnabled,
+ descriptor.m_ProjectionEnabled,
+ descriptor.m_LayerNormEnabled,
+ descriptor.m_TimeMajor);
+
+ // Index for constants vector
+ std::size_t i = 0;
+
+ // Get mandatory/basic input parameters
+ auto inputToForgetWeights = CreateConstTensorInfo(constants[i++]); //InputToForgetWeights
+ auto inputToCellWeights = CreateConstTensorInfo(constants[i++]); //InputToCellWeights
+ auto inputToOutputWeights = CreateConstTensorInfo(constants[i++]); //InputToOutputWeights
+ auto recurrentToForgetWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToForgetWeights
+ auto recurrentToCellWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToCellWeights
+ auto recurrentToOutputWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToOutputWeights
+ auto forgetGateBias = CreateConstTensorInfo(constants[i++]); //ForgetGateBias
+ auto cellBias = CreateConstTensorInfo(constants[i++]); //CellBias
+ auto outputGateBias = CreateConstTensorInfo(constants[i++]); //OutputGateBias
+
+ //Define optional parameters, these will be set depending on configuration in Lstm descriptor
+ flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+ flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+ flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+ flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
+
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputToInputWeights = CreateConstTensorInfo(constants[i++]); //InputToInputWeights
+ recurrentToInputWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToInputWeights
+ inputGateBias = CreateConstTensorInfo(constants[i++]); //InputGateBias
+ }
+
+ if (descriptor.m_PeepholeEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ cellToInputWeights = CreateConstTensorInfo(constants[i++]); //CellToInputWeights
+ }
+ cellToForgetWeights = CreateConstTensorInfo(constants[i++]); //CellToForgetWeights
+ cellToOutputWeights = CreateConstTensorInfo(constants[i++]); //CellToOutputWeights
+ }
+
+ if (descriptor.m_ProjectionEnabled)
+ {
+ projectionWeights = CreateConstTensorInfo(constants[i++]); //ProjectionWeights
+ projectionBias = CreateConstTensorInfo(constants[i++]); //ProjectionBias
+ }
+
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = CreateConstTensorInfo(constants[i++]); //InputLayerNormWeights
+ }
+ forgetLayerNormWeights = CreateConstTensorInfo(constants[i++]); //ForgetLayerNormWeights
+ cellLayerNormWeights = CreateConstTensorInfo(constants[i++]); //CellLayerNormWeights
+ outputLayerNormWeights = CreateConstTensorInfo(constants[i++]); //OutputLayerNormWeights
+ }
+
+ auto fbUnidirectionalSequenceLstmParams = serializer::CreateLstmInputParams(
+ m_flatBufferBuilder,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ forgetGateBias,
+ cellBias,
+ outputGateBias,
+ inputToInputWeights,
+ recurrentToInputWeights,
+ cellToInputWeights,
+ inputGateBias,
+ projectionWeights,
+ projectionBias,
+ cellToForgetWeights,
+ cellToOutputWeights,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
+
+ auto fbUnidirectionalSequenceLstmLayer = serializer::CreateUnidirectionalSequenceLstmLayer(
+ m_flatBufferBuilder,
+ fbUnidirectionalSequenceLstmBaseLayer,
+ fbUnidirectionalSequenceLstmDescriptor,
+ fbUnidirectionalSequenceLstmParams);
+
+ CreateAnyLayer(fbUnidirectionalSequenceLstmLayer.o, serializer::Layer::Layer_UnidirectionalSequenceLstmLayer);
+}
+
fb::Offset<serializer::LayerBase> SerializerStrategy::CreateLayerBase(const IConnectableLayer* layer,
const serializer::LayerType layerType)
{
@@ -2234,6 +2351,13 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
SerializeTransposeConvolution2dLayer(layer, layerDescriptor, constants, name);
break;
}
+ case armnn::LayerType::UnidirectionalSequenceLstm :
+ {
+ const armnn::UnidirectionalSequenceLstmDescriptor& layerDescriptor =
+ static_cast<const armnn::UnidirectionalSequenceLstmDescriptor&>(descriptor);
+ SerializeUnidirectionalSequenceLstmLayer(layer, layerDescriptor, constants, name);
+ break;
+ }
default:
{
throw InvalidArgumentException(
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index c99e87d3e9..dead8739cc 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -348,6 +348,11 @@ private:
void SerializeTransposeLayer(const armnn::IConnectableLayer* layer,
const armnn::TransposeDescriptor& descriptor,
const char* name = nullptr);
+
+ void SerializeUnidirectionalSequenceLstmLayer(const armnn::IConnectableLayer* layer,
+ const armnn::UnidirectionalSequenceLstmDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name = nullptr);
};
diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp
index c2bc8737b4..bdc37877f7 100644
--- a/src/armnnSerializer/test/LstmSerializationTests.cpp
+++ b/src/armnnSerializer/test/LstmSerializationTests.cpp
@@ -74,7 +74,7 @@ armnn::LstmInputParams ConstantVector2LstmInputParams(const std::vector<armnn::C
return lstmInputParams;
}
-// Works for Lstm and QLstm (QuantizedLstm uses different parameters)
+// Works for Lstm, UnidirectionalSequenceLstm and QLstm (QuantizedLstm uses different parameters)
template<typename Descriptor>
class VerifyLstmLayer : public LayerVerifierBaseWithDescriptor<Descriptor>
{
@@ -99,6 +99,7 @@ public:
case armnn::LayerType::Input: break;
case armnn::LayerType::Output: break;
case armnn::LayerType::Lstm:
+ case armnn::LayerType::UnidirectionalSequenceLstm:
{
this->VerifyNameAndConnections(layer, name);
const Descriptor& internalDescriptor = static_cast<const Descriptor&>(descriptor);
@@ -2195,4 +2196,507 @@ TEST_CASE("SerializeDeserializeQLstmAdvanced")
deserializedNetwork->ExecuteStrategy(checker);
}
+TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjection")
+{
+ armnn::UnidirectionalSequenceLstmDescriptor descriptor;
+ descriptor.m_ActivationFunc = 4;
+ descriptor.m_ClippingThresProj = 0.0f;
+ descriptor.m_ClippingThresCell = 0.0f;
+ descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams
+ descriptor.m_ProjectionEnabled = false;
+ descriptor.m_PeepholeEnabled = true;
+ descriptor.m_TimeMajor = false;
+
+ const uint32_t batchSize = 1;
+ const uint32_t timeSize = 2;
+ const uint32_t inputSize = 2;
+ const uint32_t numUnits = 4;
+ const uint32_t outputSize = numUnits;
+
+ armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32);
+ std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData);
+
+ std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData);
+
+ std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData);
+
+ armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32);
+ std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData);
+
+ std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData);
+
+ std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData);
+
+ armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32);
+ std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+ armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData);
+
+ std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+ armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData);
+
+ std::vector<float> forgetGateBiasData(numUnits, 1.0f);
+ armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData);
+
+ std::vector<float> cellBiasData(numUnits, 0.0f);
+ armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData);
+
+ std::vector<float> outputGateBiasData(numUnits, 0.0f);
+ armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData);
+
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+ const std::string layerName("UnidirectionalSequenceLstm");
+ armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer =
+ network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ // connect up
+ armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
+
+ inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ CHECK(deserializedNetwork);
+
+ VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
+ layerName,
+ {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+ {outputTensorInfo},
+ descriptor,
+ params);
+ deserializedNetwork->ExecuteStrategy(checker);
+}
+
+TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndProjection")
+{
+ armnn::UnidirectionalSequenceLstmDescriptor descriptor;
+ descriptor.m_ActivationFunc = 4;
+ descriptor.m_ClippingThresProj = 0.0f;
+ descriptor.m_ClippingThresCell = 0.0f;
+ descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams
+ descriptor.m_ProjectionEnabled = true;
+ descriptor.m_PeepholeEnabled = true;
+ descriptor.m_TimeMajor = false;
+
+ const uint32_t batchSize = 2;
+ const uint32_t timeSize = 2;
+ const uint32_t inputSize = 5;
+ const uint32_t numUnits = 20;
+ const uint32_t outputSize = 16;
+
+ armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32);
+ std::vector<float> inputToInputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData);
+
+ std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData);
+
+ std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData);
+
+ std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData);
+
+ armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32);
+ std::vector<float> inputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData);
+
+ std::vector<float> forgetGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData);
+
+ std::vector<float> cellBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellBias(tensorInfo20, cellBiasData);
+
+ std::vector<float> outputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData);
+
+ armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32);
+ std::vector<float> recurrentToInputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData);
+
+ std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData);
+
+ std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData);
+
+ std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData);
+
+ std::vector<float> cellToInputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData);
+
+ std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData);
+
+ std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData);
+
+ armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32);
+ std::vector<float> projectionWeightsData = GenerateRandomData<float>(tensorInfo16x20.GetNumElements());
+ armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData);
+
+ armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32);
+ std::vector<float> projectionBiasData(outputSize, 0.f);
+ armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData);
+
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // additional params because: descriptor.m_CifgEnabled = false
+ params.m_InputToInputWeights = &inputToInputWeights;
+ params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ params.m_CellToInputWeights = &cellToInputWeights;
+ params.m_InputGateBias = &inputGateBias;
+
+ // additional params because: descriptor.m_ProjectionEnabled = true
+ params.m_ProjectionWeights = &projectionWeights;
+ params.m_ProjectionBias = &projectionBias;
+
+ // additional params because: descriptor.m_PeepholeEnabled = true
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+ const std::string layerName("unidirectionalSequenceLstm");
+ armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer =
+ network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ // connect up
+ armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
+
+ inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ CHECK(deserializedNetwork);
+
+ VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
+ layerName,
+ {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+ {outputTensorInfo},
+ descriptor,
+ params);
+ deserializedNetwork->ExecuteStrategy(checker);
+}
+
+TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionWithLayerNorm")
+{
+ armnn::UnidirectionalSequenceLstmDescriptor descriptor;
+ descriptor.m_ActivationFunc = 4;
+ descriptor.m_ClippingThresProj = 0.0f;
+ descriptor.m_ClippingThresCell = 0.0f;
+ descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams
+ descriptor.m_ProjectionEnabled = true;
+ descriptor.m_PeepholeEnabled = true;
+ descriptor.m_LayerNormEnabled = true;
+ descriptor.m_TimeMajor = false;
+
+ const uint32_t batchSize = 2;
+ const uint32_t timeSize = 2;
+ const uint32_t inputSize = 5;
+ const uint32_t numUnits = 20;
+ const uint32_t outputSize = 16;
+
+ armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32);
+ std::vector<float> inputToInputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData);
+
+ std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData);
+
+ std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData);
+
+ std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData);
+
+ armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32);
+ std::vector<float> inputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData);
+
+ std::vector<float> forgetGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData);
+
+ std::vector<float> cellBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellBias(tensorInfo20, cellBiasData);
+
+ std::vector<float> outputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData);
+
+ armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32);
+ std::vector<float> recurrentToInputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData);
+
+ std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData);
+
+ std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData);
+
+ std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData);
+
+ std::vector<float> cellToInputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData);
+
+ std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData);
+
+ std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData);
+
+ armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32);
+ std::vector<float> projectionWeightsData = GenerateRandomData<float>(tensorInfo16x20.GetNumElements());
+ armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData);
+
+ armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32);
+ std::vector<float> projectionBiasData(outputSize, 0.f);
+ armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData);
+
+ std::vector<float> inputLayerNormWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor inputLayerNormWeights(tensorInfo20, inputLayerNormWeightsData);
+
+ std::vector<float> forgetLayerNormWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor forgetLayerNormWeights(tensorInfo20, forgetLayerNormWeightsData);
+
+ std::vector<float> cellLayerNormWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor cellLayerNormWeights(tensorInfo20, cellLayerNormWeightsData);
+
+ std::vector<float> outLayerNormWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor outLayerNormWeights(tensorInfo20, outLayerNormWeightsData);
+
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // additional params because: descriptor.m_CifgEnabled = false
+ params.m_InputToInputWeights = &inputToInputWeights;
+ params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ params.m_CellToInputWeights = &cellToInputWeights;
+ params.m_InputGateBias = &inputGateBias;
+
+ // additional params because: descriptor.m_ProjectionEnabled = true
+ params.m_ProjectionWeights = &projectionWeights;
+ params.m_ProjectionBias = &projectionBias;
+
+ // additional params because: descriptor.m_PeepholeEnabled = true
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+    // additional params because: descriptor.m_LayerNormEnabled = true
+ params.m_InputLayerNormWeights = &inputLayerNormWeights;
+ params.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+ params.m_CellLayerNormWeights = &cellLayerNormWeights;
+ params.m_OutputLayerNormWeights = &outLayerNormWeights;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+ const std::string layerName("unidirectionalSequenceLstm");
+ armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer =
+ network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ // connect up
+ armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
+
+ inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ CHECK(deserializedNetwork);
+
+ VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
+ layerName,
+ {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+ {outputTensorInfo},
+ descriptor,
+ params);
+ deserializedNetwork->ExecuteStrategy(checker);
+}
+
+TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectionTimeMajor")
+{
+ armnn::UnidirectionalSequenceLstmDescriptor descriptor;
+ descriptor.m_ActivationFunc = 4;
+ descriptor.m_ClippingThresProj = 0.0f;
+ descriptor.m_ClippingThresCell = 0.0f;
+ descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams
+ descriptor.m_ProjectionEnabled = false;
+ descriptor.m_PeepholeEnabled = true;
+ descriptor.m_TimeMajor = true;
+
+ const uint32_t batchSize = 1;
+ const uint32_t timeSize = 2;
+ const uint32_t inputSize = 2;
+ const uint32_t numUnits = 4;
+ const uint32_t outputSize = numUnits;
+
+ armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32);
+ std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData);
+
+ std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData);
+
+ std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData);
+
+ armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32);
+ std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData);
+
+ std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData);
+
+ std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData);
+
+ armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32);
+ std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+ armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData);
+
+ std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+ armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData);
+
+ std::vector<float> forgetGateBiasData(numUnits, 1.0f);
+ armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData);
+
+ std::vector<float> cellBiasData(numUnits, 0.0f);
+ armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData);
+
+ std::vector<float> outputGateBiasData(numUnits, 0.0f);
+ armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData);
+
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+ const std::string layerName("UnidirectionalSequenceLstm");
+ armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer =
+ network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ // connect up
+ armnn::TensorInfo inputTensorInfo({ timeSize, batchSize, inputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputTensorInfo({ timeSize, batchSize, outputSize }, armnn::DataType::Float32);
+
+ inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ CHECK(deserializedNetwork);
+
+ VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
+ layerName,
+ {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+ {outputTensorInfo},
+ descriptor,
+ params);
+ deserializedNetwork->ExecuteStrategy(checker);
+}
+
}