From a0162e17c56538ee6d72ecce4c3e0836cbb34c56 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 23 Jul 2021 14:47:49 +0100 Subject: MLCE-530 Add Serializer and Deserializer for UnidirectionalSequenceLstm Signed-off-by: Narumol Prangnawarat Change-Id: Ic1c56a57941ebede19ab8b9032e7f9df1221be7a --- docs/01_02_deserializer_serializer.dox | 2 + src/armnnDeserializer/Deserializer.cpp | 133 ++++++ src/armnnDeserializer/Deserializer.hpp | 4 + src/armnnSerializer/ArmnnSchema.fbs | 23 +- src/armnnSerializer/ArmnnSchema_generated.h | 224 ++++++++- src/armnnSerializer/Serializer.cpp | 124 +++++ src/armnnSerializer/Serializer.hpp | 5 + .../test/LstmSerializationTests.cpp | 506 ++++++++++++++++++++- 8 files changed, 1008 insertions(+), 13 deletions(-) diff --git a/docs/01_02_deserializer_serializer.dox b/docs/01_02_deserializer_serializer.dox index 811af1cfb5..81b4ad0e9b 100644 --- a/docs/01_02_deserializer_serializer.dox +++ b/docs/01_02_deserializer_serializer.dox @@ -76,6 +76,7 @@ The Arm NN SDK Serializer currently supports the following layers: - Switch - Transpose - TransposeConvolution2d +- UnidirectionalSequenceLstm More machine learning layers will be supported in future releases. @@ -163,6 +164,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers: - Switch - Transpose - TransposeConvolution2d +- UnidirectionalSequenceLstm More machine learning layers will be supported in future releases. 
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index af6ff842a7..2d9194a350 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -270,6 +270,7 @@ m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupporte m_ParserFunctions[Layer_SwitchLayer] = &DeserializerImpl::ParseSwitch; m_ParserFunctions[Layer_TransposeConvolution2dLayer] = &DeserializerImpl::ParseTransposeConvolution2d; m_ParserFunctions[Layer_TransposeLayer] = &DeserializerImpl::ParseTranspose; + m_ParserFunctions[Layer_UnidirectionalSequenceLstmLayer] = &DeserializerImpl::ParseUnidirectionalSequenceLstm; } LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -404,6 +405,8 @@ LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& gr return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeConvolution2dLayer()->base(); case Layer::Layer_TransposeLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeLayer()->base(); + case Layer::Layer_UnidirectionalSequenceLstmLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer()->base(); case Layer::Layer_NONE: default: throw ParseException(fmt::format("Layer type {} not recognized", layerType)); @@ -3325,4 +3328,134 @@ void IDeserializer::DeserializerImpl::ParseStandIn(GraphPtr graph, unsigned int RegisterOutputSlots(graph, layerIndex, layer); } +armnn::UnidirectionalSequenceLstmDescriptor IDeserializer::DeserializerImpl::GetUnidirectionalSequenceLstmDescriptor( + UnidirectionalSequenceLstmDescriptorPtr descriptor) +{ + armnn::UnidirectionalSequenceLstmDescriptor desc; + + desc.m_ActivationFunc = descriptor->activationFunc(); + desc.m_ClippingThresCell = descriptor->clippingThresCell(); + desc.m_ClippingThresProj = descriptor->clippingThresProj(); + desc.m_CifgEnabled = descriptor->cifgEnabled(); + 
desc.m_PeepholeEnabled = descriptor->peepholeEnabled(); + desc.m_ProjectionEnabled = descriptor->projectionEnabled(); + desc.m_LayerNormEnabled = descriptor->layerNormEnabled(); + desc.m_TimeMajor = descriptor->timeMajor(); + + return desc; +} + +void IDeserializer::DeserializerImpl::ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + + auto inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 3); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer(); + auto layerName = GetLayerName(graph, layerIndex); + auto flatBufferDescriptor = flatBufferLayer->descriptor(); + auto flatBufferInputParams = flatBufferLayer->inputParams(); + + auto descriptor = GetUnidirectionalSequenceLstmDescriptor(flatBufferDescriptor); + + armnn::LstmInputParams lstmInputParams; + + armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights()); + armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights()); + armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights()); + armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights()); + armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights()); + armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights()); + armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias()); + armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias()); + armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias()); + + lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights; + 
lstmInputParams.m_InputToCellWeights = &inputToCellWeights; + lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights; + lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights; + lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + lstmInputParams.m_ForgetGateBias = &forgetGateBias; + lstmInputParams.m_CellBias = &cellBias; + lstmInputParams.m_OutputGateBias = &outputGateBias; + + armnn::ConstTensor inputToInputWeights; + armnn::ConstTensor recurrentToInputWeights; + armnn::ConstTensor cellToInputWeights; + armnn::ConstTensor inputGateBias; + if (!descriptor.m_CifgEnabled) + { + inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights()); + recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights()); + inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias()); + + lstmInputParams.m_InputToInputWeights = &inputToInputWeights; + lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights; + lstmInputParams.m_InputGateBias = &inputGateBias; + + if (descriptor.m_PeepholeEnabled) + { + cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights()); + lstmInputParams.m_CellToInputWeights = &cellToInputWeights; + } + } + + armnn::ConstTensor projectionWeights; + armnn::ConstTensor projectionBias; + if (descriptor.m_ProjectionEnabled) + { + projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights()); + projectionBias = ToConstTensor(flatBufferInputParams->projectionBias()); + + lstmInputParams.m_ProjectionWeights = &projectionWeights; + lstmInputParams.m_ProjectionBias = &projectionBias; + } + + armnn::ConstTensor cellToForgetWeights; + armnn::ConstTensor cellToOutputWeights; + if (descriptor.m_PeepholeEnabled) + { + cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights()); + cellToOutputWeights = 
ToConstTensor(flatBufferInputParams->cellToOutputWeights()); + + lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights; + lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights; + } + + armnn::ConstTensor inputLayerNormWeights; + armnn::ConstTensor forgetLayerNormWeights; + armnn::ConstTensor cellLayerNormWeights; + armnn::ConstTensor outputLayerNormWeights; + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights()); + lstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights; + } + forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights()); + cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights()); + outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights()); + + lstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + lstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights; + lstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights; + } + + IConnectableLayer* layer = m_Network->AddUnidirectionalSequenceLstmLayer(descriptor, + lstmInputParams, + layerName.c_str()); + + armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + } // namespace armnnDeserializer diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index 0b05e16849..b1362c44b6 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -28,6 +28,7 @@ using TensorRawPtrVector = std::vector; using LayerRawPtr = const armnnSerializer::LayerBase *; using LayerBaseRawPtr = const armnnSerializer::LayerBase *; using LayerBaseRawPtrVector = std::vector; +using UnidirectionalSequenceLstmDescriptorPtr = const 
armnnSerializer::UnidirectionalSequenceLstmDescriptor *; class IDeserializer::DeserializerImpl { @@ -67,6 +68,8 @@ public: static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor, LstmInputParamsPtr lstmInputParams); static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr); + static armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor( + UnidirectionalSequenceLstmDescriptorPtr descriptor); static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo, const std::vector & targetDimsIn); @@ -138,6 +141,7 @@ private: void ParseSwitch(GraphPtr graph, unsigned int layerIndex); void ParseTranspose(GraphPtr graph, unsigned int layerIndex); void ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex); + void ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex); void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex, armnn::IConnectableLayer* layer); diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 32a9bba5ab..a544161c53 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -172,7 +172,8 @@ enum LayerType : uint { LogicalBinary = 59, Reduce = 60, Cast = 61, - Shape = 62 + Shape = 62, + UnidirectionalSequenceLstm = 63, } // Base layer table to be used as part of other layers @@ -915,6 +916,23 @@ table ReduceDescriptor { reduceOperation:ReduceOperation = Sum; } +table UnidirectionalSequenceLstmDescriptor { + activationFunc:uint; + clippingThresCell:float; + clippingThresProj:float; + cifgEnabled:bool = true; + peepholeEnabled:bool = false; + projectionEnabled:bool = false; + layerNormEnabled:bool = false; + timeMajor:bool = false; +} + +table UnidirectionalSequenceLstmLayer { + base:LayerBase; + descriptor:UnidirectionalSequenceLstmDescriptor; + inputParams:LstmInputParams; +} + union Layer { ActivationLayer, AdditionLayer, @@ -978,7 +996,8 @@ 
union Layer { LogicalBinaryLayer, ReduceLayer, CastLayer, - ShapeLayer + ShapeLayer, + UnidirectionalSequenceLstmLayer, } table AnyLayer { diff --git a/src/armnnSerializer/ArmnnSchema_generated.h b/src/armnnSerializer/ArmnnSchema_generated.h index 4a352ddb6c..27550f0682 100644 --- a/src/armnnSerializer/ArmnnSchema_generated.h +++ b/src/armnnSerializer/ArmnnSchema_generated.h @@ -362,6 +362,12 @@ struct ReduceLayerBuilder; struct ReduceDescriptor; struct ReduceDescriptorBuilder; +struct UnidirectionalSequenceLstmDescriptor; +struct UnidirectionalSequenceLstmDescriptorBuilder; + +struct UnidirectionalSequenceLstmLayer; +struct UnidirectionalSequenceLstmLayerBuilder; + struct AnyLayer; struct AnyLayerBuilder; @@ -740,11 +746,12 @@ enum LayerType { LayerType_Reduce = 60, LayerType_Cast = 61, LayerType_Shape = 62, + LayerType_UnidirectionalSequenceLstm = 63, LayerType_MIN = LayerType_Addition, - LayerType_MAX = LayerType_Shape + LayerType_MAX = LayerType_UnidirectionalSequenceLstm }; -inline const LayerType (&EnumValuesLayerType())[63] { +inline const LayerType (&EnumValuesLayerType())[64] { static const LayerType values[] = { LayerType_Addition, LayerType_Input, @@ -808,13 +815,14 @@ inline const LayerType (&EnumValuesLayerType())[63] { LayerType_LogicalBinary, LayerType_Reduce, LayerType_Cast, - LayerType_Shape + LayerType_Shape, + LayerType_UnidirectionalSequenceLstm }; return values; } inline const char * const *EnumNamesLayerType() { - static const char * const names[64] = { + static const char * const names[65] = { "Addition", "Input", "Multiplication", @@ -878,13 +886,14 @@ inline const char * const *EnumNamesLayerType() { "Reduce", "Cast", "Shape", + "UnidirectionalSequenceLstm", nullptr }; return names; } inline const char *EnumNameLayerType(LayerType e) { - if (flatbuffers::IsOutRange(e, LayerType_Addition, LayerType_Shape)) return ""; + if (flatbuffers::IsOutRange(e, LayerType_Addition, LayerType_UnidirectionalSequenceLstm)) return ""; const size_t index = 
static_cast(e); return EnumNamesLayerType()[index]; } @@ -1227,11 +1236,12 @@ enum Layer { Layer_ReduceLayer = 61, Layer_CastLayer = 62, Layer_ShapeLayer = 63, + Layer_UnidirectionalSequenceLstmLayer = 64, Layer_MIN = Layer_NONE, - Layer_MAX = Layer_ShapeLayer + Layer_MAX = Layer_UnidirectionalSequenceLstmLayer }; -inline const Layer (&EnumValuesLayer())[64] { +inline const Layer (&EnumValuesLayer())[65] { static const Layer values[] = { Layer_NONE, Layer_ActivationLayer, @@ -1296,13 +1306,14 @@ inline const Layer (&EnumValuesLayer())[64] { Layer_LogicalBinaryLayer, Layer_ReduceLayer, Layer_CastLayer, - Layer_ShapeLayer + Layer_ShapeLayer, + Layer_UnidirectionalSequenceLstmLayer }; return values; } inline const char * const *EnumNamesLayer() { - static const char * const names[65] = { + static const char * const names[66] = { "NONE", "ActivationLayer", "AdditionLayer", @@ -1367,13 +1378,14 @@ inline const char * const *EnumNamesLayer() { "ReduceLayer", "CastLayer", "ShapeLayer", + "UnidirectionalSequenceLstmLayer", nullptr }; return names; } inline const char *EnumNameLayer(Layer e) { - if (flatbuffers::IsOutRange(e, Layer_NONE, Layer_ShapeLayer)) return ""; + if (flatbuffers::IsOutRange(e, Layer_NONE, Layer_UnidirectionalSequenceLstmLayer)) return ""; const size_t index = static_cast(e); return EnumNamesLayer()[index]; } @@ -1634,6 +1646,10 @@ template<> struct LayerTraits { static const Layer enum_value = Layer_ShapeLayer; }; +template<> struct LayerTraits { + static const Layer enum_value = Layer_UnidirectionalSequenceLstmLayer; +}; + bool VerifyLayer(flatbuffers::Verifier &verifier, const void *obj, Layer type); bool VerifyLayerVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); @@ -9425,6 +9441,183 @@ inline flatbuffers::Offset CreateReduceDescriptorDirect( reduceOperation); } +struct UnidirectionalSequenceLstmDescriptor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef 
UnidirectionalSequenceLstmDescriptorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ACTIVATIONFUNC = 4, + VT_CLIPPINGTHRESCELL = 6, + VT_CLIPPINGTHRESPROJ = 8, + VT_CIFGENABLED = 10, + VT_PEEPHOLEENABLED = 12, + VT_PROJECTIONENABLED = 14, + VT_LAYERNORMENABLED = 16, + VT_TIMEMAJOR = 18 + }; + uint32_t activationFunc() const { + return GetField(VT_ACTIVATIONFUNC, 0); + } + float clippingThresCell() const { + return GetField(VT_CLIPPINGTHRESCELL, 0.0f); + } + float clippingThresProj() const { + return GetField(VT_CLIPPINGTHRESPROJ, 0.0f); + } + bool cifgEnabled() const { + return GetField(VT_CIFGENABLED, 1) != 0; + } + bool peepholeEnabled() const { + return GetField(VT_PEEPHOLEENABLED, 0) != 0; + } + bool projectionEnabled() const { + return GetField(VT_PROJECTIONENABLED, 0) != 0; + } + bool layerNormEnabled() const { + return GetField(VT_LAYERNORMENABLED, 0) != 0; + } + bool timeMajor() const { + return GetField(VT_TIMEMAJOR, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ACTIVATIONFUNC) && + VerifyField(verifier, VT_CLIPPINGTHRESCELL) && + VerifyField(verifier, VT_CLIPPINGTHRESPROJ) && + VerifyField(verifier, VT_CIFGENABLED) && + VerifyField(verifier, VT_PEEPHOLEENABLED) && + VerifyField(verifier, VT_PROJECTIONENABLED) && + VerifyField(verifier, VT_LAYERNORMENABLED) && + VerifyField(verifier, VT_TIMEMAJOR) && + verifier.EndTable(); + } +}; + +struct UnidirectionalSequenceLstmDescriptorBuilder { + typedef UnidirectionalSequenceLstmDescriptor Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_activationFunc(uint32_t activationFunc) { + fbb_.AddElement(UnidirectionalSequenceLstmDescriptor::VT_ACTIVATIONFUNC, activationFunc, 0); + } + void add_clippingThresCell(float clippingThresCell) { + fbb_.AddElement(UnidirectionalSequenceLstmDescriptor::VT_CLIPPINGTHRESCELL, clippingThresCell, 0.0f); + } + void 
add_clippingThresProj(float clippingThresProj) { + fbb_.AddElement(UnidirectionalSequenceLstmDescriptor::VT_CLIPPINGTHRESPROJ, clippingThresProj, 0.0f); + } + void add_cifgEnabled(bool cifgEnabled) { + fbb_.AddElement(UnidirectionalSequenceLstmDescriptor::VT_CIFGENABLED, static_cast(cifgEnabled), 1); + } + void add_peepholeEnabled(bool peepholeEnabled) { + fbb_.AddElement(UnidirectionalSequenceLstmDescriptor::VT_PEEPHOLEENABLED, static_cast(peepholeEnabled), 0); + } + void add_projectionEnabled(bool projectionEnabled) { + fbb_.AddElement(UnidirectionalSequenceLstmDescriptor::VT_PROJECTIONENABLED, static_cast(projectionEnabled), 0); + } + void add_layerNormEnabled(bool layerNormEnabled) { + fbb_.AddElement(UnidirectionalSequenceLstmDescriptor::VT_LAYERNORMENABLED, static_cast(layerNormEnabled), 0); + } + void add_timeMajor(bool timeMajor) { + fbb_.AddElement(UnidirectionalSequenceLstmDescriptor::VT_TIMEMAJOR, static_cast(timeMajor), 0); + } + explicit UnidirectionalSequenceLstmDescriptorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UnidirectionalSequenceLstmDescriptorBuilder &operator=(const UnidirectionalSequenceLstmDescriptorBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUnidirectionalSequenceLstmDescriptor( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t activationFunc = 0, + float clippingThresCell = 0.0f, + float clippingThresProj = 0.0f, + bool cifgEnabled = true, + bool peepholeEnabled = false, + bool projectionEnabled = false, + bool layerNormEnabled = false, + bool timeMajor = false) { + UnidirectionalSequenceLstmDescriptorBuilder builder_(_fbb); + builder_.add_clippingThresProj(clippingThresProj); + builder_.add_clippingThresCell(clippingThresCell); + builder_.add_activationFunc(activationFunc); + builder_.add_timeMajor(timeMajor); + 
builder_.add_layerNormEnabled(layerNormEnabled); + builder_.add_projectionEnabled(projectionEnabled); + builder_.add_peepholeEnabled(peepholeEnabled); + builder_.add_cifgEnabled(cifgEnabled); + return builder_.Finish(); +} + +struct UnidirectionalSequenceLstmLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UnidirectionalSequenceLstmLayerBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BASE = 4, + VT_DESCRIPTOR = 6, + VT_INPUTPARAMS = 8 + }; + const armnnSerializer::LayerBase *base() const { + return GetPointer(VT_BASE); + } + const armnnSerializer::UnidirectionalSequenceLstmDescriptor *descriptor() const { + return GetPointer(VT_DESCRIPTOR); + } + const armnnSerializer::LstmInputParams *inputParams() const { + return GetPointer(VT_INPUTPARAMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BASE) && + verifier.VerifyTable(base()) && + VerifyOffset(verifier, VT_DESCRIPTOR) && + verifier.VerifyTable(descriptor()) && + VerifyOffset(verifier, VT_INPUTPARAMS) && + verifier.VerifyTable(inputParams()) && + verifier.EndTable(); + } +}; + +struct UnidirectionalSequenceLstmLayerBuilder { + typedef UnidirectionalSequenceLstmLayer Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_base(flatbuffers::Offset base) { + fbb_.AddOffset(UnidirectionalSequenceLstmLayer::VT_BASE, base); + } + void add_descriptor(flatbuffers::Offset descriptor) { + fbb_.AddOffset(UnidirectionalSequenceLstmLayer::VT_DESCRIPTOR, descriptor); + } + void add_inputParams(flatbuffers::Offset inputParams) { + fbb_.AddOffset(UnidirectionalSequenceLstmLayer::VT_INPUTPARAMS, inputParams); + } + explicit UnidirectionalSequenceLstmLayerBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UnidirectionalSequenceLstmLayerBuilder &operator=(const UnidirectionalSequenceLstmLayerBuilder &); + 
flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUnidirectionalSequenceLstmLayer( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset base = 0, + flatbuffers::Offset descriptor = 0, + flatbuffers::Offset inputParams = 0) { + UnidirectionalSequenceLstmLayerBuilder builder_(_fbb); + builder_.add_inputParams(inputParams); + builder_.add_descriptor(descriptor); + builder_.add_base(base); + return builder_.Finish(); +} + struct AnyLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef AnyLayerBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { @@ -9627,6 +9820,9 @@ struct AnyLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const armnnSerializer::ShapeLayer *layer_as_ShapeLayer() const { return layer_type() == armnnSerializer::Layer_ShapeLayer ? static_cast(layer()) : nullptr; } + const armnnSerializer::UnidirectionalSequenceLstmLayer *layer_as_UnidirectionalSequenceLstmLayer() const { + return layer_type() == armnnSerializer::Layer_UnidirectionalSequenceLstmLayer ? 
static_cast(layer()) : nullptr; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_LAYER_TYPE) && @@ -9888,6 +10084,10 @@ template<> inline const armnnSerializer::ShapeLayer *AnyLayer::layer_as inline const armnnSerializer::UnidirectionalSequenceLstmLayer *AnyLayer::layer_as() const { + return layer_as_UnidirectionalSequenceLstmLayer(); +} + struct AnyLayerBuilder { typedef AnyLayer Table; flatbuffers::FlatBufferBuilder &fbb_; @@ -10360,6 +10560,10 @@ inline bool VerifyLayer(flatbuffers::Verifier &verifier, const void *obj, Layer auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case Layer_UnidirectionalSequenceLstmLayer: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index fd7f8dc7dc..44cd1800c4 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1648,6 +1648,123 @@ void SerializerStrategy::SerializeQuantizedLstmLayer(const armnn::IConnectableLa CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer); } +void SerializerStrategy::SerializeUnidirectionalSequenceLstmLayer( + const armnn::IConnectableLayer* layer, + const armnn::UnidirectionalSequenceLstmDescriptor& descriptor, + const std::vector& constants, + const char* name) +{ + IgnoreUnused(name); + + auto fbUnidirectionalSequenceLstmBaseLayer = + CreateLayerBase(layer, serializer::LayerType::LayerType_UnidirectionalSequenceLstm); + + auto fbUnidirectionalSequenceLstmDescriptor = serializer::CreateUnidirectionalSequenceLstmDescriptor( + m_flatBufferBuilder, + descriptor.m_ActivationFunc, + descriptor.m_ClippingThresCell, + descriptor.m_ClippingThresProj, + descriptor.m_CifgEnabled, + descriptor.m_PeepholeEnabled, + descriptor.m_ProjectionEnabled, + descriptor.m_LayerNormEnabled, + descriptor.m_TimeMajor); + + // 
Index for constants vector + std::size_t i = 0; + + // Get mandatory/basic input parameters + auto inputToForgetWeights = CreateConstTensorInfo(constants[i++]); //InputToForgetWeights + auto inputToCellWeights = CreateConstTensorInfo(constants[i++]); //InputToCellWeights + auto inputToOutputWeights = CreateConstTensorInfo(constants[i++]); //InputToOutputWeights + auto recurrentToForgetWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToForgetWeights + auto recurrentToCellWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToCellWeights + auto recurrentToOutputWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToOutputWeights + auto forgetGateBias = CreateConstTensorInfo(constants[i++]); //ForgetGateBias + auto cellBias = CreateConstTensorInfo(constants[i++]); //CellBias + auto outputGateBias = CreateConstTensorInfo(constants[i++]); //OutputGateBias + + //Define optional parameters, these will be set depending on configuration in Lstm descriptor + flatbuffers::Offset inputToInputWeights; + flatbuffers::Offset recurrentToInputWeights; + flatbuffers::Offset cellToInputWeights; + flatbuffers::Offset inputGateBias; + flatbuffers::Offset projectionWeights; + flatbuffers::Offset projectionBias; + flatbuffers::Offset cellToForgetWeights; + flatbuffers::Offset cellToOutputWeights; + flatbuffers::Offset inputLayerNormWeights; + flatbuffers::Offset forgetLayerNormWeights; + flatbuffers::Offset cellLayerNormWeights; + flatbuffers::Offset outputLayerNormWeights; + + if (!descriptor.m_CifgEnabled) + { + inputToInputWeights = CreateConstTensorInfo(constants[i++]); //InputToInputWeights + recurrentToInputWeights = CreateConstTensorInfo(constants[i++]); //RecurrentToInputWeights + inputGateBias = CreateConstTensorInfo(constants[i++]); //InputGateBias + } + + if (descriptor.m_PeepholeEnabled) + { + if (!descriptor.m_CifgEnabled) + { + cellToInputWeights = CreateConstTensorInfo(constants[i++]); //CellToInputWeights + } + cellToForgetWeights = 
CreateConstTensorInfo(constants[i++]); //CellToForgetWeights + cellToOutputWeights = CreateConstTensorInfo(constants[i++]); //CellToOutputWeights + } + + if (descriptor.m_ProjectionEnabled) + { + projectionWeights = CreateConstTensorInfo(constants[i++]); //ProjectionWeights + projectionBias = CreateConstTensorInfo(constants[i++]); //ProjectionBias + } + + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + inputLayerNormWeights = CreateConstTensorInfo(constants[i++]); //InputLayerNormWeights + } + forgetLayerNormWeights = CreateConstTensorInfo(constants[i++]); //ForgetLayerNormWeights + cellLayerNormWeights = CreateConstTensorInfo(constants[i++]); //CellLayerNormWeights + outputLayerNormWeights = CreateConstTensorInfo(constants[i++]); //OutputLayerNormWeights + } + + auto fbUnidirectionalSequenceLstmParams = serializer::CreateLstmInputParams( + m_flatBufferBuilder, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + forgetGateBias, + cellBias, + outputGateBias, + inputToInputWeights, + recurrentToInputWeights, + cellToInputWeights, + inputGateBias, + projectionWeights, + projectionBias, + cellToForgetWeights, + cellToOutputWeights, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights); + + auto fbUnidirectionalSequenceLstmLayer = serializer::CreateUnidirectionalSequenceLstmLayer( + m_flatBufferBuilder, + fbUnidirectionalSequenceLstmBaseLayer, + fbUnidirectionalSequenceLstmDescriptor, + fbUnidirectionalSequenceLstmParams); + + CreateAnyLayer(fbUnidirectionalSequenceLstmLayer.o, serializer::Layer::Layer_UnidirectionalSequenceLstmLayer); +} + fb::Offset SerializerStrategy::CreateLayerBase(const IConnectableLayer* layer, const serializer::LayerType layerType) { @@ -2234,6 +2351,13 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer, 
SerializeTransposeConvolution2dLayer(layer, layerDescriptor, constants, name); break; } + case armnn::LayerType::UnidirectionalSequenceLstm : + { + const armnn::UnidirectionalSequenceLstmDescriptor& layerDescriptor = + static_cast(descriptor); + SerializeUnidirectionalSequenceLstmLayer(layer, layerDescriptor, constants, name); + break; + } default: { throw InvalidArgumentException( diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index c99e87d3e9..dead8739cc 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -348,6 +348,11 @@ private: void SerializeTransposeLayer(const armnn::IConnectableLayer* layer, const armnn::TransposeDescriptor& descriptor, const char* name = nullptr); + + void SerializeUnidirectionalSequenceLstmLayer(const armnn::IConnectableLayer* layer, + const armnn::UnidirectionalSequenceLstmDescriptor& descriptor, + const std::vector& constants, + const char* name = nullptr); }; diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp index c2bc8737b4..bdc37877f7 100644 --- a/src/armnnSerializer/test/LstmSerializationTests.cpp +++ b/src/armnnSerializer/test/LstmSerializationTests.cpp @@ -74,7 +74,7 @@ armnn::LstmInputParams ConstantVector2LstmInputParams(const std::vector class VerifyLstmLayer : public LayerVerifierBaseWithDescriptor { @@ -99,6 +99,7 @@ public: case armnn::LayerType::Input: break; case armnn::LayerType::Output: break; case armnn::LayerType::Lstm: + case armnn::LayerType::UnidirectionalSequenceLstm: { this->VerifyNameAndConnections(layer, name); const Descriptor& internalDescriptor = static_cast(descriptor); @@ -2195,4 +2196,507 @@ TEST_CASE("SerializeDeserializeQLstmAdvanced") deserializedNetwork->ExecuteStrategy(checker); } +TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjection") +{ + armnn::UnidirectionalSequenceLstmDescriptor descriptor; + descriptor.m_ActivationFunc = 
4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = true; + descriptor.m_TimeMajor = false; + + const uint32_t batchSize = 1; + const uint32_t timeSize = 2; + const uint32_t inputSize = 2; + const uint32_t numUnits = 4; + const uint32_t outputSize = numUnits; + + armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToForgetWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToForgetWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32); + std::vector cellToForgetWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor 
cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData); + + std::vector forgetGateBiasData(numUnits, 1.0f); + armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData); + + std::vector cellBiasData(numUnits, 0.0f); + armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData); + + std::vector outputGateBiasData(numUnits, 0.0f); + armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("UnidirectionalSequenceLstm"); + armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer = + network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, 
armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {outputTensorInfo}, + descriptor, + params); + deserializedNetwork->ExecuteStrategy(checker); +} + +TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndProjection") +{ + armnn::UnidirectionalSequenceLstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = true; + descriptor.m_PeepholeEnabled = true; + descriptor.m_TimeMajor = false; + + const uint32_t batchSize = 2; + const uint32_t timeSize = 2; + const uint32_t inputSize = 5; + const uint32_t numUnits = 20; + const uint32_t outputSize = 16; + + armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToInputWeightsData = 
GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData); + + std::vector inputToForgetWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData); + + armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32); + std::vector inputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData); + + std::vector forgetGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData); + + std::vector cellBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellBias(tensorInfo20, cellBiasData); + + std::vector outputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData); + + armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToInputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData); + + std::vector recurrentToForgetWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, 
recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData); + + std::vector cellToInputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData); + + std::vector cellToForgetWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData); + + armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32); + std::vector projectionWeightsData = GenerateRandomData(tensorInfo16x20.GetNumElements()); + armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData); + + armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32); + std::vector projectionBiasData(outputSize, 0.f); + armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // additional params because: descriptor.m_CifgEnabled = false + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_CellToInputWeights = &cellToInputWeights; + params.m_InputGateBias = &inputGateBias; 
+ + // additional params because: descriptor.m_ProjectionEnabled = true + params.m_ProjectionWeights = &projectionWeights; + params.m_ProjectionBias = &projectionBias; + + // additional params because: descriptor.m_PeepholeEnabled = true + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("unidirectionalSequenceLstm"); + armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer = + network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + 
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {outputTensorInfo}, + descriptor, + params); + deserializedNetwork->ExecuteStrategy(checker); +} + +TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionWithLayerNorm") +{ + armnn::UnidirectionalSequenceLstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = true; + descriptor.m_PeepholeEnabled = true; + descriptor.m_LayerNormEnabled = true; + descriptor.m_TimeMajor = false; + + const uint32_t batchSize = 2; + const uint32_t timeSize = 2; + const uint32_t inputSize = 5; + const uint32_t numUnits = 20; + const uint32_t outputSize = 16; + + armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToInputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData); + + std::vector inputToForgetWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData); + + armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32); + std::vector inputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + 
armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData); + + std::vector forgetGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData); + + std::vector cellBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellBias(tensorInfo20, cellBiasData); + + std::vector outputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData); + + armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToInputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData); + + std::vector recurrentToForgetWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData); + + std::vector cellToInputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData); + + std::vector cellToForgetWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData); + + armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32); + 
std::vector projectionWeightsData = GenerateRandomData(tensorInfo16x20.GetNumElements()); + armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData); + + armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32); + std::vector projectionBiasData(outputSize, 0.f); + armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData); + + std::vector inputLayerNormWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor inputLayerNormWeights(tensorInfo20, forgetGateBiasData); + + std::vector forgetLayerNormWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor forgetLayerNormWeights(tensorInfo20, forgetGateBiasData); + + std::vector cellLayerNormWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellLayerNormWeights(tensorInfo20, forgetGateBiasData); + + std::vector outLayerNormWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor outLayerNormWeights(tensorInfo20, forgetGateBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // additional params because: descriptor.m_CifgEnabled = false + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_CellToInputWeights = &cellToInputWeights; + params.m_InputGateBias = &inputGateBias; + + // additional params because: descriptor.m_ProjectionEnabled = true + params.m_ProjectionWeights = &projectionWeights; + 
params.m_ProjectionBias = &projectionBias; + + // additional params because: descriptor.m_PeepholeEnabled = true + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + // additional params because: descriptor.m_LayerNormEnabled = true + params.m_InputLayerNormWeights = &inputLayerNormWeights; + params.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + params.m_CellLayerNormWeights = &cellLayerNormWeights; + params.m_OutputLayerNormWeights = &outLayerNormWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("unidirectionalSequenceLstm"); + armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer = + network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + 
unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {outputTensorInfo}, + descriptor, + params); + deserializedNetwork->ExecuteStrategy(checker); +} + +TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectionTimeMajor") +{ + armnn::UnidirectionalSequenceLstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = true; + descriptor.m_TimeMajor = true; + + const uint32_t batchSize = 1; + const uint32_t timeSize = 2; + const uint32_t inputSize = 2; + const uint32_t numUnits = 4; + const uint32_t outputSize = numUnits; + + armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToForgetWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToForgetWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + 
armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32); + std::vector cellToForgetWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData); + + std::vector forgetGateBiasData(numUnits, 1.0f); + armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData); + + std::vector cellBiasData(numUnits, 0.0f); + armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData); + + std::vector outputGateBiasData(numUnits, 0.0f); + armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + 
armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("UnidirectionalSequenceLstm"); + armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer = + network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + // connect up + armnn::TensorInfo inputTensorInfo({ timeSize, batchSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({ timeSize, batchSize, outputSize }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {outputTensorInfo}, + descriptor, + params); + deserializedNetwork->ExecuteStrategy(checker); +} + } -- cgit v1.2.1