From 8ed39ae450a077c7e4d672b5f05ff1d68ee67aab Mon Sep 17 00:00:00 2001
From: Narumol Prangnawarat
Date: Thu, 15 Jul 2021 16:16:25 +0100
Subject: MLCE-530 Add front end support for UnidirectionalSequenceLstm on ArmNN

Signed-off-by: Narumol Prangnawarat
Change-Id: I57bcbdec3eb0155f41af0fe7d6abf9bac2ec86eb
---
 Android.mk                                         |   1 +
 CMakeLists.txt                                     |   2 +
 include/armnn/BackendHelper.hpp                    |  11 +
 include/armnn/Descriptors.hpp                      |   8 +-
 include/armnn/DescriptorsFwd.hpp                   |   1 +
 include/armnn/INetwork.hpp                         |   9 +
 include/armnn/Types.hpp                            |   9 +-
 include/armnn/backends/ILayerSupport.hpp           |  11 +
 src/armnn/BackendHelper.cpp                        |  21 +
 src/armnn/LayersFwd.hpp                            |   4 +-
 src/armnn/Network.cpp                              | 152 ++++++-
 src/armnn/Network.hpp                              |   4 +
 src/armnn/layers/LstmLayer.hpp                     |  63 +--
 src/armnn/layers/LstmParameters.hpp                |  76 ++++
 .../layers/UnidirectionalSequenceLstmLayer.cpp     | 492 +++++++++++++++++++++
 .../layers/UnidirectionalSequenceLstmLayer.hpp     |  65 +++
 src/backends/backendsCommon/LayerSupportBase.cpp   |  13 +
 src/backends/backendsCommon/LayerSupportBase.hpp   |  11 +
 src/backends/backendsCommon/WorkloadData.cpp       | 276 +++++++++++-
 src/backends/backendsCommon/WorkloadData.hpp       |  52 +++
 src/backends/backendsCommon/WorkloadFactory.cpp    | 148 +++++++
 src/backends/backendsCommon/WorkloadFactory.hpp    |   4 +
 .../test/IsLayerSupportedTestImpl.hpp              |  53 +++
 23 files changed, 1416 insertions(+), 70 deletions(-)
 create mode 100644 src/armnn/layers/LstmParameters.hpp
 create mode 100644 src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp
 create mode 100644 src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp

diff --git a/Android.mk b/Android.mk
index 79e5623cd0..d3f1dcf75e 100644
--- a/Android.mk
+++ b/Android.mk
@@ -218,6 +218,7 @@ LOCAL_SRC_FILES := \
         src/armnn/layers/SwitchLayer.cpp \
         src/armnn/layers/TransposeConvolution2dLayer.cpp \
         src/armnn/layers/TransposeLayer.cpp \
+        src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp \
         src/armnn/layers/UnmapLayer.cpp \
         src/profiling/ActivateTimelineReportingCommandHandler.cpp \
         src/profiling/BufferManager.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 13d3937689..0156a19d8c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -331,6 +331,8 @@ list(APPEND armnn_sources
     src/armnn/layers/TransposeConvolution2dLayer.hpp
     src/armnn/layers/TransposeLayer.hpp
     src/armnn/layers/TransposeLayer.cpp
+    src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp
+    src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp
     src/armnn/layers/UnmapLayer.cpp
     src/armnn/layers/UnmapLayer.hpp
     src/armnn/AsyncExecutionCallback.cpp
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index 093f822040..dee3b48b81 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -433,6 +433,17 @@ public:
                               const TransposeDescriptor& descriptor,
                               Optional<std::string&> reasonIfUnsupported = EmptyOptional());

+    bool IsUnidirectionalSequenceLstmSupported(
+        const TensorInfo& input,
+        const TensorInfo& outputStateIn,
+        const TensorInfo& cellStateIn,
+        const TensorInfo& output,
+        const Optional<TensorInfo>& hiddenStateOutput,
+        const Optional<TensorInfo>& cellStateOutput,
+        const LstmDescriptor& descriptor,
+        const LstmInputParamsInfo& paramsInfo,
+        Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
 private:
     std::shared_ptr<ILayerSupport> m_LayerSupport;
     const BackendId m_BackendId;
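For reference, callers can probe a backend for the new layer through the LayerSupportHandle returned by GetILayerSupportByBackendId, just as for the existing Is*Supported checks. A minimal sketch, assuming illustrative Float32 shapes (a two-step, single-batch sequence) and a paramsInfo populated elsewhere:

    // Sketch: query UnidirectionalSequenceLstm support on CpuRef (shapes are assumptions).
    armnn::LayerSupportHandle handle = armnn::GetILayerSupportByBackendId(armnn::BackendId("CpuRef"));

    armnn::TensorInfo input({2, 1, 10}, armnn::DataType::Float32);      // [time, batch, input], time major
    armnn::TensorInfo outputStateIn({1, 16}, armnn::DataType::Float32); // [batch, output]
    armnn::TensorInfo cellStateIn({1, 20}, armnn::DataType::Float32);   // [batch, num_units]
    armnn::TensorInfo output({2, 1, 16}, armnn::DataType::Float32);     // [time, batch, output]

    armnn::UnidirectionalSequenceLstmDescriptor descriptor; // alias of LstmDescriptor, see Descriptors.hpp below
    armnn::LstmInputParamsInfo paramsInfo;                  // weight/bias TensorInfos filled in elsewhere
    std::string reason;

    bool supported = handle.IsUnidirectionalSequenceLstmSupported(
        input, outputStateIn, cellStateIn, output,
        armnn::EmptyOptional(),  // hidden state output not requested
        armnn::EmptyOptional(),  // cell state output not requested
        descriptor, paramsInfo, armnn::Optional<std::string&>(reason));

If a backend leaves the new virtual unimplemented, the LayerSupportBase default further down in this patch reports the layer as unsupported and fills in the reason string.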
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 683ef7ac98..bcee902d75 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -926,6 +926,7 @@ struct LstmDescriptor : BaseDescriptor
         , m_PeepholeEnabled(false)
         , m_ProjectionEnabled(false)
         , m_LayerNormEnabled(false)
+        , m_TimeMajor(true)
     {}

     bool operator ==(const LstmDescriptor& rhs) const
@@ -935,7 +936,8 @@ struct LstmDescriptor : BaseDescriptor
                m_ClippingThresProj == rhs.m_ClippingThresProj &&
                m_CifgEnabled == rhs.m_CifgEnabled &&
                m_PeepholeEnabled == rhs.m_PeepholeEnabled &&
-               m_LayerNormEnabled == rhs.m_LayerNormEnabled;
+               m_LayerNormEnabled == rhs.m_LayerNormEnabled &&
+               m_TimeMajor == rhs.m_TimeMajor;
     }

     /// @brief The activation function to use.
@@ -953,8 +955,12 @@ struct LstmDescriptor : BaseDescriptor
     bool m_ProjectionEnabled;
     /// Enable/disable layer normalization
     bool m_LayerNormEnabled;
+    /// Enable/disable time major
+    bool m_TimeMajor;
 };

+using UnidirectionalSequenceLstmDescriptor = LstmDescriptor;
+
 /// A MeanDescriptor for the MeanLayer.
 struct MeanDescriptor : BaseDescriptor
 {
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 9b22644c7b..3b43c42d23 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -55,5 +55,6 @@ using LogSoftmaxDescriptor = SoftmaxDescriptor;
 /// MergerDescriptor is deprecated, use ConcatDescriptor instead
 using MergerDescriptor = OriginsDescriptor;
 using SplitterDescriptor = ViewsDescriptor;
+using UnidirectionalSequenceLstmDescriptor = LstmDescriptor;

 } // namespace armnn
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index b40db62a59..865d1291a9 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -691,6 +691,15 @@ public:
     IConnectableLayer* AddLogicalBinaryLayer(const LogicalBinaryDescriptor& descriptor,
                                              const char* name = nullptr);

+    /// Add a UnidirectionalSequenceLstm layer to the network
+    /// @param descriptor - Parameters for the UnidirectionalSequenceLstm operation
+    /// @param params - Weights and biases for the UnidirectionalSequenceLstm
+    /// @param name - Optional name for the layer
+    /// @return - Interface for configuring the layer.
+    IConnectableLayer* AddUnidirectionalSequenceLstmLayer(const UnidirectionalSequenceLstmDescriptor& descriptor,
+                                                          const LstmInputParams& params,
+                                                          const char* name = nullptr);
+
     void Accept(ILayerVisitor& visitor) const;

     void ExecuteStrategy(IStrategy& strategy) const;
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index e7c17608ca..056aa83d2f 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -333,7 +333,6 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
     X(ArgMinMax) \
     X(BatchNormalization) \
     X(BatchToSpaceNd) \
-    X(Cast) \
     X(Comparison) \
     X(Concat) \
     X(Constant) \
@@ -382,7 +381,6 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
     X(Rank) \
     X(Resize) \
     X(Reduce) \
-    X(Shape) \
     X(Slice) \
     X(Softmax) \
     X(SpaceToBatchNd) \
@@ -396,6 +394,11 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
     X(Transpose) \
     X(TransposeConvolution2d) \
     X(Unmap) \
+    X(Cast) \
+    X(Shape) \
+    X(UnidirectionalSequenceLstm) \
+
+// New layers should be added last to minimize instability.
 /// When adding a new layer, adapt also the LastLayer enum value in the
 /// enum class LayerType below
@@ -405,7 +408,7 @@ enum class LayerType
     LIST_OF_LAYER_TYPE
 #undef X
     FirstLayer = Activation,
-    LastLayer = Unmap
+    LastLayer = UnidirectionalSequenceLstm
 };

 const char* GetLayerTypeAsCString(LayerType type);
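The LIST_OF_LAYER_TYPE X-macro above feeds both the LayerType enum and GetLayerTypeAsCString, which is why new entries go at the end: inserting in the middle would renumber every later enum value. A simplified, self-contained sketch of the pattern (not the full ArmNN list):

    // X-macro: one list, two expansions.
    #define LIST_OF_LAYER_TYPE \
        X(Activation)          \
        X(Addition)            \
        X(UnidirectionalSequenceLstm)

    enum class LayerType
    {
    #define X(name) name,
        LIST_OF_LAYER_TYPE
    #undef X
        FirstLayer = Activation,
        LastLayer = UnidirectionalSequenceLstm
    };

    const char* GetLayerTypeAsCString(LayerType type)
    {
        switch (type)
        {
    #define X(name) case LayerType::name: return #name;
            LIST_OF_LAYER_TYPE
    #undef X
            default: return "Unknown";
        }
    }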
diff --git a/include/armnn/backends/ILayerSupport.hpp b/include/armnn/backends/ILayerSupport.hpp
index 462668d738..7ba565a138 100644
--- a/include/armnn/backends/ILayerSupport.hpp
+++ b/include/armnn/backends/ILayerSupport.hpp
@@ -424,6 +424,17 @@ public:
                               const TransposeDescriptor& descriptor,
                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;

+    virtual bool IsUnidirectionalSequenceLstmSupported(
+        const TensorInfo& input,
+        const TensorInfo& outputStateIn,
+        const TensorInfo& cellStateIn,
+        const TensorInfo& output,
+        const Optional<TensorInfo>& hiddenStateOutput,
+        const Optional<TensorInfo>& cellStateOutput,
+        const LstmDescriptor& descriptor,
+        const LstmInputParamsInfo& paramsInfo,
+        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
 }; // class ILayerSupport

 using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp
index a7bf419a7c..13bde0aafa 100644
--- a/src/armnn/BackendHelper.cpp
+++ b/src/armnn/BackendHelper.cpp
@@ -842,4 +842,25 @@ bool LayerSupportHandle::IsTransposeSupported(const TensorInfo& input,
     return m_LayerSupport->IsTransposeSupported(input, output, descriptor, reasonIfUnsupported.value());
 }

+bool LayerSupportHandle::IsUnidirectionalSequenceLstmSupported(const TensorInfo& input,
+                                                               const TensorInfo& outputStateIn,
+                                                               const TensorInfo& cellStateIn,
+                                                               const TensorInfo& output,
+                                                               const Optional<TensorInfo>& hiddenStateOutput,
+                                                               const Optional<TensorInfo>& cellStateOutput,
+                                                               const LstmDescriptor& descriptor,
+                                                               const LstmInputParamsInfo& paramsInfo,
+                                                               Optional<std::string&> reasonIfUnsupported)
+{
+    return m_LayerSupport->IsUnidirectionalSequenceLstmSupported(input,
+                                                                 outputStateIn,
+                                                                 cellStateIn,
+                                                                 output,
+                                                                 hiddenStateOutput,
+                                                                 cellStateOutput,
+                                                                 descriptor,
+                                                                 paramsInfo,
+                                                                 reasonIfUnsupported);
+}
+
 }
\ No newline at end of file
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index cdbcaa7e90..e3ae23cf40 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -73,6 +73,7 @@
 #include "layers/SwitchLayer.hpp"
 #include "layers/TransposeConvolution2dLayer.hpp"
 #include "layers/TransposeLayer.hpp"
+#include "layers/UnidirectionalSequenceLstmLayer.hpp"
 #include "layers/UnmapLayer.hpp"

 namespace armnn
@@ -107,6 +108,7 @@ DECLARE_LAYER(Addition)
 DECLARE_LAYER(ArgMinMax)
 DECLARE_LAYER(BatchNormalization)
 DECLARE_LAYER(BatchToSpaceNd)
+DECLARE_LAYER(Cast)
 DECLARE_LAYER(Comparison)
 DECLARE_LAYER(Concat)
 DECLARE_LAYER(Constant)
@@ -168,6 +170,6 @@ DECLARE_LAYER(Subtraction)
 DECLARE_LAYER(Switch)
 DECLARE_LAYER(Transpose)
 DECLARE_LAYER(TransposeConvolution2d)
+DECLARE_LAYER(UnidirectionalSequenceLstm)
 DECLARE_LAYER(Unmap)
-DECLARE_LAYER(Cast)
 }
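DECLARE_LAYER in LayersFwd.hpp maps each LayerType value to its concrete layer class at compile time, which is what lets the workload factory downcast a Layer& once it has switched on GetType(). A simplified sketch of the mechanism (an assumption about the shape of the real ArmNN macro, which differs in detail):

    // Compile-time LayerType -> layer class mapping (simplified).
    template <armnn::LayerType Type>
    struct LayerTypeOfImpl;

    #define DECLARE_LAYER(LayerName)                        \
        template <>                                         \
        struct LayerTypeOfImpl<armnn::LayerType::LayerName> \
        {                                                   \
            using Type = armnn::LayerName##Layer;           \
        };

    DECLARE_LAYER(UnidirectionalSequenceLstm)

    template <armnn::LayerType Type>
    using LayerTypeOf = typename LayerTypeOfImpl<Type>::Type;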
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index d340f021e2..83eafe7993 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -518,6 +518,14 @@ IConnectableLayer* INetwork::AddLogicalBinaryLayer(const LogicalBinaryDescriptor
     return pNetworkImpl->AddLogicalBinaryLayer(descriptor, name);
 }

+IConnectableLayer* INetwork::AddUnidirectionalSequenceLstmLayer(
+    const UnidirectionalSequenceLstmDescriptor& descriptor,
+    const LstmInputParams& params,
+    const char* name)
+{
+    return pNetworkImpl->AddUnidirectionalSequenceLstmLayer(descriptor, params, name);
+}
+
 void INetwork::Accept(ILayerVisitor& visitor) const
 {
     return pNetworkImpl->Accept(visitor);
@@ -2603,11 +2611,153 @@ IConnectableLayer* NetworkImpl::AddQLstmLayer(const QLstmDescriptor& descriptor
 }

 IConnectableLayer* NetworkImpl::AddLogicalBinaryLayer(const LogicalBinaryDescriptor& logicalBinaryDescriptor,
-        const char* name)
+                                                      const char* name)
 {
     return m_Graph->AddLayer<LogicalBinaryLayer>(logicalBinaryDescriptor, name);
 }

+IConnectableLayer* NetworkImpl::AddUnidirectionalSequenceLstmLayer(
+    const UnidirectionalSequenceLstmDescriptor& descriptor,
+    const LstmInputParams& params,
+    const char* name)
+{
+    const auto layer = m_Graph->AddLayer<UnidirectionalSequenceLstmLayer>(descriptor, name);
+
+    //Lstm Basic Parameters
+    layer->m_BasicParameters.m_InputToForgetWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_InputToForgetWeights));
+    layer->m_BasicParameters.m_InputToCellWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_InputToCellWeights));
+    layer->m_BasicParameters.m_InputToOutputWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_InputToOutputWeights));
+    layer->m_BasicParameters.m_RecurrentToForgetWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToForgetWeights));
+    layer->m_BasicParameters.m_RecurrentToCellWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToCellWeights));
+    layer->m_BasicParameters.m_RecurrentToOutputWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToOutputWeights));
+    layer->m_BasicParameters.m_ForgetGateBias =
+        std::make_shared<ScopedTensorHandle>(*(params.m_ForgetGateBias));
+    layer->m_BasicParameters.m_CellBias =
+        std::make_shared<ScopedTensorHandle>(*(params.m_CellBias));
+    layer->m_BasicParameters.m_OutputGateBias =
+        std::make_shared<ScopedTensorHandle>(*(params.m_OutputGateBias));
+
+    //Lstm Cifg parameters
+    if(!descriptor.m_CifgEnabled)
+    {
+        if(params.m_InputToInputWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input To Input Weights cannot be NULL "
+                                           "when CIFG is disabled.");
+        }
+        if(params.m_RecurrentToInputWeights == nullptr)
+        {
+            throw InvalidArgumentException(
+                "AddUnidirectionalSequenceLstmLayer: Recurrent To Input Weights cannot be NULL "
+                "when CIFG is disabled.");
+        }
+        if(params.m_InputGateBias == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input Gate Bias cannot be NULL "
+                                           "when CIFG is disabled.");
+        }
+        layer->m_CifgParameters.m_InputToInputWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_InputToInputWeights));
+        layer->m_CifgParameters.m_RecurrentToInputWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToInputWeights));
+        layer->m_CifgParameters.m_InputGateBias =
+            std::make_shared<ScopedTensorHandle>(*(params.m_InputGateBias));
+    }
+
+    //Lstm projection parameters
+    if(descriptor.m_ProjectionEnabled)
+    {
+        if(params.m_ProjectionWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Projection Weights cannot be NULL "
+                                           "when projection is enabled.");
+        }
+        layer->m_ProjectionParameters.m_ProjectionWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_ProjectionWeights));
+        if(params.m_ProjectionBias != nullptr)
+        {
+            layer->m_ProjectionParameters.m_ProjectionBias =
+                std::make_shared<ScopedTensorHandle>(*(params.m_ProjectionBias));
+        }
+    }
+
+    //Lstm Peephole params
+    if(descriptor.m_PeepholeEnabled)
+    {
+        if(!descriptor.m_CifgEnabled)
+        {
+            if(params.m_CellToInputWeights == nullptr)
+            {
+                throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Input Weights "
+                                               "cannot be NULL when Peephole is enabled and CIFG disabled.");
+            }
+
+            layer->m_PeepholeParameters.m_CellToInputWeights =
+                std::make_shared<ScopedTensorHandle>(*(params.m_CellToInputWeights));
+        }
+
+        if(params.m_CellToForgetWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Forget Weights cannot be NULL "
+                                           "when Peephole is enabled.");
+        }
+        if(params.m_CellToOutputWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Output Weights cannot be NULL "
+                                           "when Peephole is enabled.");
+        }
+
+        layer->m_PeepholeParameters.m_CellToForgetWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_CellToForgetWeights));
+        layer->m_PeepholeParameters.m_CellToOutputWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_CellToOutputWeights));
+    }
+
+    //Lstm Layer Normalization params
+    if(descriptor.m_LayerNormEnabled)
+    {
+        if(!descriptor.m_CifgEnabled)
+        {
+            if(params.m_InputLayerNormWeights == nullptr)
+            {
+                throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input layer normalization weights "
+                                               "cannot be NULL when layer normalization is enabled and CIFG disabled.");
+            }
+            layer->m_LayerNormParameters.m_InputLayerNormWeights =
+                std::make_shared<ScopedTensorHandle>(*(params.m_InputLayerNormWeights));
+        }
+
+        if(params.m_ForgetLayerNormWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Forget layer normalization weights "
+                                           "cannot be NULL when layer normalization is enabled.");
+        }
+        if(params.m_CellLayerNormWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell layer normalization weights "
+                                           "cannot be NULL when layer normalization is enabled.");
+        }
+        if(params.m_OutputLayerNormWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Output layer normalization weights "
+                                           "cannot be NULL when layer normalization is enabled.");
+        }
+        layer->m_LayerNormParameters.m_ForgetLayerNormWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_ForgetLayerNormWeights));
+        layer->m_LayerNormParameters.m_CellLayerNormWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_CellLayerNormWeights));
+        layer->m_LayerNormParameters.m_OutputLayerNormWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_OutputLayerNormWeights));
+    }
+    return layer;
+}
+
 void NetworkImpl::Accept(ILayerVisitor& visitor) const
 {
     for (auto layer : GetGraph())
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index 6f9be5635a..54c3497c90 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -269,6 +269,10 @@ public:
     IConnectableLayer* AddTransposeLayer(const TransposeDescriptor& transposeDescriptor,
                                          const char* name = nullptr);

+    IConnectableLayer* AddUnidirectionalSequenceLstmLayer(const UnidirectionalSequenceLstmDescriptor& descriptor,
+                                                          const LstmInputParams& params,
+                                                          const char* name = nullptr);
+
     void Accept(ILayerVisitor& visitor) const;

     void ExecuteStrategy(IStrategy& strategy) const;
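A caller-side sketch of the new entry point, using the smallest legal configuration (CIFG enabled, so only the nine basic tensors are required). All shapes and values are illustrative assumptions, not taken from this patch:

    using namespace armnn;

    INetworkPtr network = INetwork::Create();

    UnidirectionalSequenceLstmDescriptor descriptor;
    descriptor.m_CifgEnabled = true;   // input-gate tensors may then be omitted
    descriptor.m_TimeMajor   = true;   // input laid out as [time, batch, input]

    const unsigned int numUnits = 4, inputSize = 3, outputSize = 4;   // illustrative sizes
    TensorInfo inputWeightInfo({numUnits, inputSize}, DataType::Float32);
    TensorInfo recurrentWeightInfo({numUnits, outputSize}, DataType::Float32);
    TensorInfo biasInfo({numUnits}, DataType::Float32);

    std::vector<float> inputWeightData(numUnits * inputSize, 0.1f);
    std::vector<float> recurrentWeightData(numUnits * outputSize, 0.1f);
    std::vector<float> biasData(numUnits, 0.0f);

    ConstTensor inputToForget(inputWeightInfo, inputWeightData);
    ConstTensor inputToCell(inputWeightInfo, inputWeightData);
    ConstTensor inputToOutput(inputWeightInfo, inputWeightData);
    ConstTensor recurrentToForget(recurrentWeightInfo, recurrentWeightData);
    ConstTensor recurrentToCell(recurrentWeightInfo, recurrentWeightData);
    ConstTensor recurrentToOutput(recurrentWeightInfo, recurrentWeightData);
    ConstTensor forgetGateBias(biasInfo, biasData);
    ConstTensor cellBias(biasInfo, biasData);
    ConstTensor outputGateBias(biasInfo, biasData);

    LstmInputParams params;
    params.m_InputToForgetWeights     = &inputToForget;
    params.m_InputToCellWeights       = &inputToCell;
    params.m_InputToOutputWeights     = &inputToOutput;
    params.m_RecurrentToForgetWeights = &recurrentToForget;
    params.m_RecurrentToCellWeights   = &recurrentToCell;
    params.m_RecurrentToOutputWeights = &recurrentToOutput;
    params.m_ForgetGateBias           = &forgetGateBias;
    params.m_CellBias                 = &cellBias;
    params.m_OutputGateBias           = &outputGateBias;

    IConnectableLayer* lstm =
        network->AddUnidirectionalSequenceLstmLayer(descriptor, params, "useq-lstm");

Omitting any of the nine basic tensors, or any tensor implied by the enabled CIFG/projection/peephole/layer-norm flags, makes the checks above throw InvalidArgumentException at AddUnidirectionalSequenceLstmLayer time rather than at workload creation.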
diff --git a/src/armnn/layers/LstmLayer.hpp b/src/armnn/layers/LstmLayer.hpp
index f711ea7607..dc6d12a1d8 100644
--- a/src/armnn/layers/LstmLayer.hpp
+++ b/src/armnn/layers/LstmLayer.hpp
@@ -5,74 +5,13 @@
 #pragma once

 #include "LayerWithParameters.hpp"
+#include "LstmParameters.hpp"

 namespace armnn
 {

 class ScopedTensorHandle;

-struct LstmOptLayerNormParameters
-{
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_InputLayerNormWeights;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_ForgetLayerNormWeights;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_CellLayerNormWeights;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_OutputLayerNormWeights;
-};
-
-struct LstmOptCifgParameters
-{
-    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_InputToInputWeights;
-    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_RecurrentToInputWeights;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_InputGateBias;
-};
-
-struct LstmOptProjectionParameters
-{
-    /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_ProjectionWeights;
-    /// A unique pointer to represent 1D weights tensor with dimensions [output_size].
-    std::shared_ptr<ConstTensorHandle> m_ProjectionBias;
-};
-
-struct LstmOptPeepholeParameters
-{
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_CellToInputWeights;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_CellToForgetWeights;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_CellToOutputWeights;
-};
-
-struct LstmBasicParameters
-{
-    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_InputToForgetWeights;
-    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_InputToCellWeights;
-    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_InputToOutputWeights;
-    /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_RecurrentToForgetWeights;
-    /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_RecurrentToCellWeights;
-    /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
-    std::shared_ptr<ConstTensorHandle> m_RecurrentToOutputWeights;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_ForgetGateBias;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_CellBias;
-    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
-    std::shared_ptr<ConstTensorHandle> m_OutputGateBias;
-};
-
 /// This layer represents a LSTM operation.
 class LstmLayer : public LayerWithParameters<LstmDescriptor>
 {
diff --git a/src/armnn/layers/LstmParameters.hpp b/src/armnn/layers/LstmParameters.hpp
new file mode 100644
index 0000000000..3809ea875f
--- /dev/null
+++ b/src/armnn/layers/LstmParameters.hpp
@@ -0,0 +1,76 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "LayerWithParameters.hpp"
+
+namespace armnn
+{
+
+class ScopedTensorHandle;
+
+struct LstmOptLayerNormParameters
+{
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_InputLayerNormWeights;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_ForgetLayerNormWeights;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_CellLayerNormWeights;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_OutputLayerNormWeights;
+};
+
+struct LstmOptCifgParameters
+{
+    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_InputToInputWeights;
+    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_RecurrentToInputWeights;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_InputGateBias;
+};
+
+struct LstmOptProjectionParameters
+{
+    /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_ProjectionWeights;
+    /// A unique pointer to represent 1D weights tensor with dimensions [output_size].
+    std::shared_ptr<ConstTensorHandle> m_ProjectionBias;
+};
+
+struct LstmOptPeepholeParameters
+{
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_CellToInputWeights;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_CellToForgetWeights;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_CellToOutputWeights;
+};
+
+struct LstmBasicParameters
+{
+    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_InputToForgetWeights;
+    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_InputToCellWeights;
+    /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_InputToOutputWeights;
+    /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_RecurrentToForgetWeights;
+    /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_RecurrentToCellWeights;
+    /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
+    std::shared_ptr<ConstTensorHandle> m_RecurrentToOutputWeights;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_ForgetGateBias;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_CellBias;
+    /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+    std::shared_ptr<ConstTensorHandle> m_OutputGateBias;
+};
+
+} // namespace
diff --git a/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp
new file mode 100644
index 0000000000..45417069e4
--- /dev/null
+++ b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp
@@ -0,0 +1,492 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "UnidirectionalSequenceLstmLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <armnn/LstmParams.hpp>
+#include <armnn/TypesUtils.hpp>
+#include <backendsCommon/TensorHandle.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+UnidirectionalSequenceLstmLayer::UnidirectionalSequenceLstmLayer(const LstmDescriptor& param, const char* name)
+    : LayerWithParameters(3, 1, LayerType::UnidirectionalSequenceLstm, param, name)
+{
+}
+
+std::unique_ptr<IWorkload> UnidirectionalSequenceLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
+{
+    UnidirectionalSequenceLstmQueueDescriptor descriptor;
+
+    // Basic parameters
+    descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
+    descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
+    descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
+    descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
+    descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
+    descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
+    descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
+    descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
+    descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();
+
+    // Cifg parameters
+    if (!m_Param.m_CifgEnabled)
+    {
+        descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get();
+        descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
+        descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get();
+    }
+
+    // Projection parameters
+    if (m_Param.m_ProjectionEnabled)
+    {
+        descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
+        descriptor.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias.get();
+    }
+
+    // Peephole parameters
+    if (m_Param.m_PeepholeEnabled)
+    {
+        if (!m_Param.m_CifgEnabled)
+        {
+            descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
+        }
+        descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
+        descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
+    }
+
+    // Layer normalisation parameters
+    if(m_Param.m_LayerNormEnabled)
+    {
+        if (!m_Param.m_CifgEnabled)
+        {
+            descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
+        }
+        descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
+        descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
+        descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
+    }
+
+    SetAdditionalInfo(descriptor);
+
+    return factory.CreateUnidirectionalSequenceLstm(descriptor, PrepInfoAndDesc(descriptor));
+}
+
+UnidirectionalSequenceLstmLayer* UnidirectionalSequenceLstmLayer::Clone(Graph& graph) const
+{
+    auto layer = CloneBase<UnidirectionalSequenceLstmLayer>(graph, m_Param, GetName());
+
+    layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
+                                                          m_BasicParameters.m_InputToForgetWeights
+                                                          : nullptr;
+    layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
+                                                        m_BasicParameters.m_InputToCellWeights : nullptr;
+    layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
+                                                          m_BasicParameters.m_InputToOutputWeights : nullptr;
+    layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
+                                                              m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
+    layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
+                                                            m_BasicParameters.m_RecurrentToCellWeights : nullptr;
+    layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
+                                                              m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
+    layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
+                                                    m_BasicParameters.m_ForgetGateBias : nullptr;
+    layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
+                                              m_BasicParameters.m_CellBias : nullptr;
+    layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
+                                                    m_BasicParameters.m_OutputGateBias : nullptr;
+
+    if (!m_Param.m_CifgEnabled)
+    {
+        layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
+                                                            m_CifgParameters.m_InputToInputWeights : nullptr;
+        layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
+                                                                m_CifgParameters.m_RecurrentToInputWeights : nullptr;
+        layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
+                                                      m_CifgParameters.m_InputGateBias : nullptr;
+    }
+
+    if (m_Param.m_ProjectionEnabled)
+    {
+        layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
+                                                                m_ProjectionParameters.m_ProjectionWeights : nullptr;
+        layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
+                                                             m_ProjectionParameters.m_ProjectionBias : nullptr;
+    }
+
+    if (m_Param.m_PeepholeEnabled)
+    {
+        if (!m_Param.m_CifgEnabled)
+        {
+            layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
+                                                                   m_PeepholeParameters.m_CellToInputWeights : nullptr;
+        }
+        layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
+                                                                m_PeepholeParameters.m_CellToForgetWeights : nullptr;
+        layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
+                                                                m_PeepholeParameters.m_CellToOutputWeights : nullptr;
+    }
+
+    if (m_Param.m_LayerNormEnabled)
+    {
+        layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
+                                                                   m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
+        layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
+                                                                    m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
+        layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
+                                                                  m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
+        layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
+                                                                    m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
+    }
+
+    return std::move(layer);
+}
+
+std::vector<TensorShape> UnidirectionalSequenceLstmLayer::InferOutputShapes(
+    const std::vector<TensorShape>& inputShapes) const
+{
+    ARMNN_ASSERT(inputShapes.size() == 3);
+
+    // Get input values for validation
+    unsigned int outputSize = inputShapes[1][1];
+
+    // The output mirrors the first two dimensions of the input (time/batch order is
+    // the same in both layouts), so the two branches build the same shape.
+    std::vector<TensorShape> outShapes;
+    if (m_Param.m_TimeMajor)
+    {
+        outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize}));
+    }
+    else
+    {
+        outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize}));
+    }
+    return outShapes;
+}
+
+void UnidirectionalSequenceLstmLayer::ValidateTensorShapesFromInputs()
+{
+    VerifyLayerConnections(3, CHECK_LOCATION());
+
+    const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
+
+    VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
+
+    auto inferredShapes = InferOutputShapes( {
+        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
+        GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()
+    });
+
+    ARMNN_ASSERT(inferredShapes.size() == 1);
+
+    // Check if the weights are nullptr
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights "
+                     "should not be null.");
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights "
+                     "should not be null.");
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_CellBias should not be null.");
+    ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
+                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");
+
+    if (!m_Param.m_CifgEnabled)
+    {
+        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
+        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights "
+                         "should not be null.");
+        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not be null.");
+    }
+    else
+    {
+        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value "
+                         "when CIFG is enabled.");
+        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value "
+                         "when CIFG is enabled.");
+        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not have a value "
+                         "when CIFG is enabled.");
+    }
+
+    if (m_Param.m_ProjectionEnabled)
+    {
+        ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_ProjectionParameters.m_ProjectionWeights "
+                         "should not be null.");
+    }
+
+    if (m_Param.m_PeepholeEnabled)
+    {
+        if (!m_Param.m_CifgEnabled)
+        {
+            ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
+                             "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToInputWeights "
+                             "should not be null "
+                             "when Peephole is enabled and CIFG is disabled.");
+        }
+        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToForgetWeights "
+                         "should not be null.");
+        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToOutputWeights "
+                         "should not be null.");
+    }
+
+    if (m_Param.m_LayerNormEnabled)
+    {
+        if(!m_Param.m_CifgEnabled)
+        {
+            ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
+                             "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_InputLayerNormWeights "
+                             "should not be null.");
+        }
+        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_ForgetLayerNormWeights "
+                         "should not be null.");
+        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_CellLayerNormWeights "
+                         "should not be null.");
+        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
+                         "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_OutputLayerNormWeights "
+                         "should not be null.");
+    }
+
+    ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "UnidirectionalSequenceLstmLayer");
+}
+
+Layer::ConstantTensors UnidirectionalSequenceLstmLayer::GetConstantTensorsByRef()
+{
+    return {m_BasicParameters.m_InputToForgetWeights,
+            m_BasicParameters.m_InputToCellWeights,
+            m_BasicParameters.m_InputToOutputWeights,
+            m_BasicParameters.m_RecurrentToForgetWeights,
+            m_BasicParameters.m_RecurrentToCellWeights,
+            m_BasicParameters.m_RecurrentToOutputWeights,
+            m_BasicParameters.m_ForgetGateBias,
+            m_BasicParameters.m_CellBias,
+            m_BasicParameters.m_OutputGateBias,
+
+            // Cifg parameters
+            m_CifgParameters.m_InputToInputWeights,
+            m_CifgParameters.m_RecurrentToInputWeights,
+            m_CifgParameters.m_InputGateBias,
+
+            // Projection parameters
+            m_ProjectionParameters.m_ProjectionWeights,
+            m_ProjectionParameters.m_ProjectionBias,
+
+            // Peephole parameters
+            m_PeepholeParameters.m_CellToInputWeights,
+            m_PeepholeParameters.m_CellToForgetWeights,
+            m_PeepholeParameters.m_CellToOutputWeights,
+
+            // Layer normalisation parameters
+            m_LayerNormParameters.m_InputLayerNormWeights,
+            m_LayerNormParameters.m_ForgetLayerNormWeights,
+            m_LayerNormParameters.m_CellLayerNormWeights,
+            m_LayerNormParameters.m_OutputLayerNormWeights};
+}
+
+void UnidirectionalSequenceLstmLayer::Accept(ILayerVisitor& visitor) const
+{
+    IgnoreUnused(visitor);
+    throw armnn::Exception("UnidirectionalSequenceLstmLayer: VisitUnidirectionalSequenceLstmLayer is not implemented");
+}
+
+void UnidirectionalSequenceLstmLayer::ExecuteStrategy(IStrategy& strategy) const
+{
+    std::vector<ConstTensor> constTensors;
+
+    LstmDescriptor descriptor = GetParameters();
+
+    ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
+    ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
+    ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
+    ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
+    ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
+    ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
+    ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
+    ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
+    ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);
+
+    // Cifg parameters
+    ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
+    ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
+    ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);
+
+    // Projection parameters
+    ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
+    ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);
+
+    // Peephole parameters
+    ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
+    ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
+    ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);
+
+    // Layer normalisation parameters
+    ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
+    ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
+    ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
+    ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);
+
+    // First add mandatory/basic parameters
+    if (m_BasicParameters.m_InputToForgetWeights != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
+                                              managedInputToForgetWeights.Map()));
+    }
+    if (m_BasicParameters.m_InputToCellWeights != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
+                                              managedInputToCellWeights.Map()));
+    }
+    if (m_BasicParameters.m_InputToOutputWeights != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
+                                              managedInputToOutputWeights.Map()));
+    }
+    if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(
+            managedRecurrentToForgetWeights.GetTensorInfo(),
+            managedRecurrentToForgetWeights.Map()));
+    }
+    if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(
+            managedRecurrentToCellWeights.GetTensorInfo(),
+            managedRecurrentToCellWeights.Map()));
+    }
+    if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(
+            managedRecurrentToOutputWeights.GetTensorInfo(),
+            managedRecurrentToOutputWeights.Map()));
+    }
+    if (m_BasicParameters.m_ForgetGateBias != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
+                                              managedForgetGateBias.Map()));
+    }
+    if (m_BasicParameters.m_CellBias != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
+                                              managedCellBias.Map()));
+    }
+    if (m_BasicParameters.m_OutputGateBias != nullptr)
+    {
+        constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
+                                              managedOutputGateBias.Map()));
+    }
+
+    // Add cifg parameters
+    if (!descriptor.m_CifgEnabled)
+    {
+        if (m_CifgParameters.m_InputToInputWeights != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
+                                                  managedInputToInputWeights.Map()));
+        }
+        if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(
+                managedRecurrentToInputWeights.GetTensorInfo(),
+                managedRecurrentToInputWeights.Map()));
+        }
+        if (m_CifgParameters.m_InputGateBias != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
+                                                  managedInputGateBias.Map()));
+        }
+    }
+
+    // Add peephole parameters
+    if (descriptor.m_PeepholeEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
+            {
+                constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
+                                                      managedCellToInputWeights.Map()));
+            }
+        }
+        if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
+                                                  managedCellToForgetWeights.Map()));
+        }
+        if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
+                                                  managedCellToOutputWeights.Map()));
+        }
+    }
+
+    // Add projection parameters
+    if (descriptor.m_ProjectionEnabled)
+    {
+        if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
+                                                  managedProjectionWeights.Map()));
+        }
+        if (m_ProjectionParameters.m_ProjectionBias != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
+                                                  managedProjectionBias.Map()));
+        }
+    }
+
+    // Add norm parameters
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
+            {
+                constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
+                                                      managedInputLayerNormWeights.Map()));
+            }
+        }
+        if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
+                                                  managedForgetLayerNormWeights.Map()));
+        }
+        if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
+                                                  managedCellLayerNormWeights.Map()));
+        }
+        if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
+        {
+            constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
+                                                  managedOutputLayerNormWeights.Map()));
+        }
+    }
+
+    strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
+}
+
+} // namespace armnn
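Continuing the sketch from the Network.cpp hunk above: the layer takes three inputs (the input sequence, the initial output state and the initial cell state) and produces a single output sequence, so wiring it up looks like this (same illustrative shapes and identifiers as before):

    IConnectableLayer* input         = network->AddInputLayer(0, "input");
    IConnectableLayer* outputStateIn = network->AddInputLayer(1, "outputStateIn");
    IConnectableLayer* cellStateIn   = network->AddInputLayer(2, "cellStateIn");
    IConnectableLayer* outputLayer   = network->AddOutputLayer(0, "output");

    input->GetOutputSlot(0).Connect(lstm->GetInputSlot(0));
    outputStateIn->GetOutputSlot(0).Connect(lstm->GetInputSlot(1));
    cellStateIn->GetOutputSlot(0).Connect(lstm->GetInputSlot(2));
    lstm->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));

    // Time-major layout: input [time, batch, inputSize], states [batch, outputSize] and
    // [batch, numUnits], output [time, batch, outputSize].
    input->GetOutputSlot(0).SetTensorInfo(TensorInfo({2, 1, inputSize}, DataType::Float32));
    outputStateIn->GetOutputSlot(0).SetTensorInfo(TensorInfo({1, outputSize}, DataType::Float32));
    cellStateIn->GetOutputSlot(0).SetTensorInfo(TensorInfo({1, numUnits}, DataType::Float32));
    lstm->GetOutputSlot(0).SetTensorInfo(TensorInfo({2, 1, outputSize}, DataType::Float32));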
diff --git a/src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp b/src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp
new file mode 100644
index 0000000000..fb59f01ab6
--- /dev/null
+++ b/src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp
@@ -0,0 +1,65 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "LayerWithParameters.hpp"
+#include "LstmParameters.hpp"
+
+namespace armnn
+{
+
+class ScopedTensorHandle;
+
+/// This layer represents a UnidirectionalSequenceLstm operation.
+class UnidirectionalSequenceLstmLayer : public LayerWithParameters<LstmDescriptor>
+{
+public:
+
+    LstmBasicParameters m_BasicParameters;
+    LstmOptCifgParameters m_CifgParameters;
+    LstmOptProjectionParameters m_ProjectionParameters;
+    LstmOptPeepholeParameters m_PeepholeParameters;
+    LstmOptLayerNormParameters m_LayerNormParameters;
+
+    /// Makes a workload for the UnidirectionalSequenceLstm type.
+    /// @param [in] graph The graph where this layer can be found.
+    /// @param [in] factory The workload factory which will create the workload.
+    /// @return A pointer to the created workload, or nullptr if not created.
+    virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override;
+
+    /// Creates a dynamically-allocated copy of this layer.
+    /// @param [in] graph The graph into which this layer is being cloned.
+    UnidirectionalSequenceLstmLayer* Clone(Graph& graph) const override;
+
+    /// Check if the input tensor shape(s)
+    /// will lead to a valid configuration of @ref UnidirectionalSequenceLstmLayer.
+    /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validated.
+    void ValidateTensorShapesFromInputs() override;
+
+    /// By default returns inputShapes if the number of inputs is equal to the number of outputs,
+    /// otherwise infers the output shapes from the given input shapes and layer properties.
+    /// @param [in] inputShapes The input shapes the layer has.
+    /// @return A vector of the inferred output shapes.
+    std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
+    void Accept(ILayerVisitor& visitor) const override;
+
+    void ExecuteStrategy(IStrategy& strategy) const override;
+
+protected:
+    /// Constructor to create a UnidirectionalSequenceLstmLayer.
+    /// @param [in] param LstmDescriptor to configure the lstm operation.
+    /// @param [in] name Optional name for the layer.
+    UnidirectionalSequenceLstmLayer(const LstmDescriptor& param, const char* name);
+
+    /// Default destructor
+    ~UnidirectionalSequenceLstmLayer() = default;
+
+    /// Retrieve the handles to the constant values stored by the layer.
+    /// @return A vector of the constant tensors stored by this layer.
+    Layer::ConstantTensors GetConstantTensorsByRef() override;
+};
+
+} // namespace
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 8a24e1161b..138d45367e 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -678,4 +678,17 @@ bool LayerSupportBase::IsTransposeSupported(const TensorInfo&, // input
     return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
 }

+bool LayerSupportBase::IsUnidirectionalSequenceLstmSupported(const TensorInfo&, // input
+                                                             const TensorInfo&, // outputStateIn
+                                                             const TensorInfo&, // cellStateIn
+                                                             const TensorInfo&, // output
+                                                             const Optional<TensorInfo>&, // hiddenStateOut
+                                                             const Optional<TensorInfo>&, // cellStateOut
+                                                             const LstmDescriptor&, // descriptor
+                                                             const LstmInputParamsInfo&, // paramsInfo
+                                                             Optional<std::string&> reasonIfUnsupported) const
+{
+    return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
 } // namespace armnn
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 0277a782a1..533a2c6bdd 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -417,6 +417,17 @@ public:
                               const TransposeDescriptor& descriptor,
                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;

+    bool IsUnidirectionalSequenceLstmSupported(
+        const TensorInfo& input,
+        const TensorInfo& outputStateIn,
+        const TensorInfo& cellStateIn,
+        const TensorInfo& output,
+        const Optional<TensorInfo>& hiddenStateOutput,
+        const Optional<TensorInfo>& cellStateOutput,
+        const LstmDescriptor& descriptor,
+        const LstmInputParamsInfo& paramsInfo,
+        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
 };

 } // namespace armnn
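Backends opt in by overriding the new virtual; the LayerSupportBase default above keeps the layer unsupported, so existing backends are unaffected. A hypothetical backend override might look like this (MyLayerSupport and the support policy are illustrative assumptions):

    // Hypothetical backend LayerSupport override (illustrative only).
    bool MyLayerSupport::IsUnidirectionalSequenceLstmSupported(
        const TensorInfo& input,
        const TensorInfo& outputStateIn,
        const TensorInfo& cellStateIn,
        const TensorInfo& output,
        const Optional<TensorInfo>& hiddenStateOutput,
        const Optional<TensorInfo>& cellStateOutput,
        const LstmDescriptor& descriptor,
        const LstmInputParamsInfo& paramsInfo,
        Optional<std::string&> reasonIfUnsupported) const
    {
        IgnoreUnused(outputStateIn, cellStateIn, hiddenStateOutput, cellStateOutput, paramsInfo);

        // Example policy: Float32 only, projection not handled.
        if (input.GetDataType() != DataType::Float32 || descriptor.m_ProjectionEnabled)
        {
            if (reasonIfUnsupported)
            {
                reasonIfUnsupported.value() = "Only Float32, non-projection configurations are supported.";
            }
            return false;
        }
        return output.GetDataType() == input.GetDataType();
    }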
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 8c78136185..3fe0823b03 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1959,7 +1959,6 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
         throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
     }

-    // Inferring batch size, number of outputs and number of cells from the inputs.
     const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
     const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
@@ -1991,7 +1990,6 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
     ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
                                 descriptorName + " output_3");

-    // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
     if ( m_InputToInputWeights )
     {
@@ -3741,4 +3739,278 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
 }

+void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
+
+    const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
+
+    // check dimensions of all inputs and outputs
+    if (workloadInfo.m_InputTensorInfos.size() != 3)
+    {
+        throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
+    }
+    if (workloadInfo.m_OutputTensorInfos.size() != 1)
+    {
+        throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
+    }
+
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::Float16,
+        DataType::Float32,
+        DataType::QAsymmS8
+    };
+
+    // check for supported type of one input and match them with all the other input and output
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
+
+    // type matches all other inputs
+    for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
+    {
+        ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
+                                     workloadInfo.m_InputTensorInfos[i],
+                                     descriptorName,
+                                     "input_0",
+                                     "input_" + std::to_string(i));
+    }
+    // type matches all other outputs
+    for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
+    {
+        ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
+                                     workloadInfo.m_OutputTensorInfos[i],
+                                     descriptorName,
+                                     "input_0",
+                                     "output_" + std::to_string(i));
+    }
+
+    // Making sure clipping parameters have valid values.
+    // == 0 means no clipping
+    //  > 0 means clipping
+    if (m_Parameters.m_ClippingThresCell < 0.0f)
+    {
+        throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
+    }
+    if (m_Parameters.m_ClippingThresProj < 0.0f)
+    {
+        throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
+    }
+
+    // Dimension bookkeeping: the default layout is batch major [batch, time, input];
+    // with m_TimeMajor the first two dimensions swap. The innermost dimension is always the input size.
+    unsigned int batchIndx = 0;
+    unsigned int timeIndx = 1;
+    const unsigned int inputIndx = 2;
+    if (m_Parameters.m_TimeMajor)
+    {
+        batchIndx = 1;
+        timeIndx = 0;
+    }
+    const uint32_t timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
+
+    // Inferring batch size, number of outputs and number of cells from the inputs.
+    const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
+    const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
+    ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
+    const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
+    ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
+    const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
+
+    // input tensor
+    ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
+                                descriptorName + " input_0");
+    // outputStateInTensor
+    ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
+                                descriptorName + " input_1");
+    // cellStateInTensor
+    ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
+                                descriptorName + " input_2");
+
+    // outputTensor
+    ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 3, (timeStep * n_batch * n_output),
+                                descriptorName + " output_0");
+
+    // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
+    if ( m_InputToInputWeights )
+    {
+        ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
+                                    (n_cell * n_input), "InputToInputWeights");
+    }
+
+    ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
+    ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
+                                (n_cell * n_input), "InputToForgetWeights");
+
+    ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
+    ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
+                                (n_cell * n_input), "InputToCellWeights");
+
+    if ( m_RecurrentToInputWeights )
+    {
+        ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
+                                    (n_cell * n_output), "RecurrentToInputWeights");
+    }
+
+    ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
+    ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
+                                (n_cell * n_output), "RecurrentToForgetWeights");
+
+    ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
+    ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
+                                (n_cell * n_output), "RecurrentToCellWeights");
+
+    // Make sure the input-gate's parameters are either both present (regular
+    // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
+    bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
+                                      !m_Parameters.m_CifgEnabled) ||
+                                     (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
+                                      m_Parameters.m_CifgEnabled));
+    if (!cifg_weights_all_or_none)
+    {
+        throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
+                                       "RecurrentToInputWeights must either both be present (regular LSTM) "
+                                       "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
+                                       "accordingly.");
+    }
+
+    if ( m_CellToInputWeights )
+    {
+        ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
+                                    n_cell, "CellToInputWeights");
+    }
+    if ( m_CellToForgetWeights )
+    {
+        ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
+                                    n_cell, "CellToForgetWeights");
+    }
+    if ( m_CellToOutputWeights )
+    {
+        ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
+                                    n_cell, "CellToOutputWeights");
+    }
+
+    // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
+    bool peephole_weights_all_or_none =
+        (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
+          && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
+         || ( !m_CellToInputWeights && !m_CellToForgetWeights
+              && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
+    if (!peephole_weights_all_or_none)
+    {
+        throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
+    }
+
+    // Make sure the input gate bias is present only when not a CIFG-LSTM.
+    if (m_Parameters.m_CifgEnabled)
+    {
+        if (m_InputGateBias)
+        {
+            throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
+        }
+    }
+    else
+    {
+        if (!m_InputGateBias)
+        {
+            throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
+                                           "must be present.");
+        }
+        ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
+                                    n_cell, "InputGateBias");
+    }
+
+    ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
+    ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
+
+    ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
+    ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
+
+    ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
+    ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
+
+    if (m_ProjectionWeights)
+    {
+        ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
+                                    (n_cell * n_output), "ProjectionWeights");
+    }
+    if (m_ProjectionBias)
+    {
+        ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
+    }
+
+    // Making sure the projection tensors are consistent:
+    // 1) If projection weight is not present, then projection bias should not be
+    // present.
+    // 2) If projection weight is present, then projection bias is optional.
+    bool projection_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
+                                           !m_Parameters.m_ProjectionEnabled)
+                                          || (m_ProjectionWeights && !m_ProjectionBias &&
+                                              m_Parameters.m_ProjectionEnabled)
+                                          || (m_ProjectionWeights && m_ProjectionBias &&
+                                              m_Parameters.m_ProjectionEnabled));
+    if (!projection_tensors_consistent)
+    {
+        throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
+    }
+    if (m_InputLayerNormWeights)
+    {
+        ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
+    }
+    if (m_ForgetLayerNormWeights)
+    {
+        ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
+    }
+    if (m_CellLayerNormWeights)
+    {
+        ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
+    }
+    if (m_OutputLayerNormWeights)
+    {
+        ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
+    }
+
+    if (m_Parameters.m_LayerNormEnabled)
+    {
+        if (!m_Parameters.m_CifgEnabled)
+        {
+            if (!m_InputLayerNormWeights)
+            {
+                throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
+                                               "disabled but InputLayerNormWeights are not present.");
+            }
+            ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
+                                        1, n_cell, "InputLayerNormWeights");
+        }
+        else if (m_InputLayerNormWeights)
+        {
+            throw InvalidArgumentException(descriptorName + ": InputLayerNormWeights are present while CIFG is "
+                                           "enabled.");
+        }
+
+        ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
+                        "ForgetLayerNormWeights");
+        ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
+
+        ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
+                        "OutputLayerNormWeights");
+        ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
+
+        ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
+                        "CellLayerNormWeights");
+        ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
+    }
+    else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights
+             || m_CellLayerNormWeights)
+    {
+        throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
+                                       "normalisation weights are present.");
+    }
+}
+
 } // namespace armnn
\ No newline at end of file
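Aside for reviewers: the new Validate() is essentially a chain of all-or-none consistency rules (CIFG, peephole, projection, layer normalisation). Below is a minimal, self-contained sketch of the CIFG rule in isolation; the struct and names are illustrative stand-ins, not ArmNN types, for anyone who wants to sanity-check the predicate outside the build:

    #include <stdexcept>
    #include <string>

    // Placeholder stand-ins for the descriptor fields the real check reads.
    struct CifgParams
    {
        const void* m_InputToInputWeights = nullptr;
        const void* m_RecurrentToInputWeights = nullptr;
        bool m_CifgEnabled = false;
    };

    // Mirrors cifg_weights_all_or_none above: both input-gate weight tensors
    // present with CIFG disabled, or both absent with CIFG enabled.
    void CheckCifgConsistency(const CifgParams& p, const std::string& descriptorName)
    {
        bool allPresent = p.m_InputToInputWeights && p.m_RecurrentToInputWeights && !p.m_CifgEnabled;
        bool allAbsent  = !p.m_InputToInputWeights && !p.m_RecurrentToInputWeights && p.m_CifgEnabled;
        if (!(allPresent || allAbsent))
        {
            throw std::invalid_argument(descriptorName + ": input-gate weights must either both be "
                                        "present (regular LSTM) or both absent (CIFG-LSTM).");
        }
    }

The peephole, projection and layer-normalisation rules follow the same shape, each pairing a presence predicate with the corresponding m_Parameters flag.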
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 36653bdc0d..78da00be5d 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -695,4 +695,56 @@ struct ShapeQueueDescriptor : QueueDescriptor
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
+struct UnidirectionalSequenceLstmQueueDescriptor : QueueDescriptorWithParameters<LstmDescriptor>
+{
+    UnidirectionalSequenceLstmQueueDescriptor()
+        : m_InputToInputWeights(nullptr)
+        , m_InputToForgetWeights(nullptr)
+        , m_InputToCellWeights(nullptr)
+        , m_InputToOutputWeights(nullptr)
+        , m_RecurrentToInputWeights(nullptr)
+        , m_RecurrentToForgetWeights(nullptr)
+        , m_RecurrentToCellWeights(nullptr)
+        , m_RecurrentToOutputWeights(nullptr)
+        , m_CellToInputWeights(nullptr)
+        , m_CellToForgetWeights(nullptr)
+        , m_CellToOutputWeights(nullptr)
+        , m_InputGateBias(nullptr)
+        , m_ForgetGateBias(nullptr)
+        , m_CellBias(nullptr)
+        , m_OutputGateBias(nullptr)
+        , m_ProjectionWeights(nullptr)
+        , m_ProjectionBias(nullptr)
+        , m_InputLayerNormWeights(nullptr)
+        , m_ForgetLayerNormWeights(nullptr)
+        , m_CellLayerNormWeights(nullptr)
+        , m_OutputLayerNormWeights(nullptr)
+    {
+    }
+
+    const ConstTensorHandle* m_InputToInputWeights;
+    const ConstTensorHandle* m_InputToForgetWeights;
+    const ConstTensorHandle* m_InputToCellWeights;
+    const ConstTensorHandle* m_InputToOutputWeights;
+    const ConstTensorHandle* m_RecurrentToInputWeights;
+    const ConstTensorHandle* m_RecurrentToForgetWeights;
+    const ConstTensorHandle* m_RecurrentToCellWeights;
+    const ConstTensorHandle* m_RecurrentToOutputWeights;
+    const ConstTensorHandle* m_CellToInputWeights;
+    const ConstTensorHandle* m_CellToForgetWeights;
+    const ConstTensorHandle* m_CellToOutputWeights;
+    const ConstTensorHandle* m_InputGateBias;
+    const ConstTensorHandle* m_ForgetGateBias;
+    const ConstTensorHandle* m_CellBias;
+    const ConstTensorHandle* m_OutputGateBias;
+    const ConstTensorHandle* m_ProjectionWeights;
+    const ConstTensorHandle* m_ProjectionBias;
+    const ConstTensorHandle* m_InputLayerNormWeights;
+    const ConstTensorHandle* m_ForgetLayerNormWeights;
+    const ConstTensorHandle* m_CellLayerNormWeights;
+    const ConstTensorHandle* m_OutputLayerNormWeights;
+
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
 } // namespace armnn
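A hedged usage sketch for the new descriptor (not part of the patch): how a caller might populate the mandatory handles before Validate() runs. ScopedTensorHandle is the handle type the test code later in this patch uses; the include paths, sizes and the wrapping function are assumptions for illustration only:

    #include <armnn/Tensor.hpp>
    #include <backendsCommon/TensorHandle.hpp>  // assumed location of ScopedTensorHandle
    #include <backendsCommon/WorkloadData.hpp>

    void PopulateAndValidate(const armnn::WorkloadInfo& info)
    {
        using namespace armnn;

        UnidirectionalSequenceLstmQueueDescriptor data;
        data.m_Parameters.m_CifgEnabled = true;  // omit the input-gate tensors entirely
        data.m_Parameters.m_TimeMajor   = true;  // input laid out [timeSteps, batch, inputSize]

        // Illustrative sizes: n_cell = 4, n_input = 2.
        TensorInfo weightInfo(TensorShape({4, 2}), DataType::Float32);
        ScopedTensorHandle inputToForgetWeights(weightInfo);
        data.m_InputToForgetWeights = &inputToForgetWeights;
        // ...the remaining mandatory weights and biases must be set the same way...

        data.Validate(info);  // throws InvalidArgumentException on any inconsistency
    }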
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index dc70e6a9c2..1c18551679 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -1277,6 +1277,147 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                      reason);
             break;
         }
+        case LayerType::UnidirectionalSequenceLstm:
+        {
+            auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
+            const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
+
+            // All inputs.
+            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+                                                       dataType);
+            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
+                                                               dataType);
+            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
+                                                             dataType);
+            // Outputs
+            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+
+            // Basic parameters
+            const TensorInfo& inputToForgetWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
+            const TensorInfo& inputToCellWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
+            const TensorInfo& inputToOutputWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
+            const TensorInfo& recurrentToForgetWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
+                                       dataType);
+            const TensorInfo& recurrentToCellWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
+            const TensorInfo& recurrentToOutputWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
+                                       dataType);
+            const TensorInfo& forgetGateBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
+            const TensorInfo& cellBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
+            const TensorInfo& outputGateBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
+
+            LstmInputParamsInfo paramsInfo;
+
+            paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
+            paramsInfo.m_InputToCellWeights = &inputToCellWeights;
+            paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
+            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+            paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
+            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+            paramsInfo.m_ForgetGateBias = &forgetGateBias;
+            paramsInfo.m_CellBias = &cellBias;
+            paramsInfo.m_OutputGateBias = &outputGateBias;
+
+            // Optional parameters
+            TensorInfo optInputToInputWeights;
+            TensorInfo optRecurrentToInputWeights;
+            TensorInfo optCellToInputWeights;
+            TensorInfo optInputGateBias;
+            TensorInfo optProjectionWeights;
+            TensorInfo optProjectionBias;
+            TensorInfo optCellToForgetWeights;
+            TensorInfo optCellToOutputWeights;
+            TensorInfo optInputLayerNormWeights;
+            TensorInfo optForgetLayerNormWeights;
+            TensorInfo optCellLayerNormWeights;
+            TensorInfo optOutputLayerNormWeights;
+
+            if(!descriptor.m_CifgEnabled)
+            {
+                optInputToInputWeights =
+                        OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
+
+                optRecurrentToInputWeights =
+                        OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(),
+                                         dataType);
+                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
+                optInputGateBias =
+                        OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
+                paramsInfo.m_InputGateBias = &optInputGateBias;
+            }
+
+            if(descriptor.m_ProjectionEnabled)
+            {
+                optProjectionWeights =
+                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
+                                         dataType);
+                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
+                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
+                {
+                    optProjectionBias =
+                            OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
+                                             dataType);
+                    paramsInfo.m_ProjectionBias = &optProjectionBias;
+                }
+            }
+
+            if(descriptor.m_PeepholeEnabled)
+            {
+                if(!descriptor.m_CifgEnabled)
+                {
+                    optCellToInputWeights =
+                            OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
+                                             dataType);
+                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
+                }
+                optCellToForgetWeights =
+                        OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
+                                         dataType);
+                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
+                optCellToOutputWeights =
+                        OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
+                                         dataType);
+                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
+            }
+
+            if(descriptor.m_LayerNormEnabled)
+            {
+                if (!descriptor.m_CifgEnabled)
+                {
+                    optInputLayerNormWeights = OverrideDataType(
+                            cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
+                    paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
+                }
+
+                optForgetLayerNormWeights = OverrideDataType(
+                        cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
+
+                optCellLayerNormWeights = OverrideDataType(
+                        cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
+
+                optOutputLayerNormWeights = OverrideDataType(
+                        cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
+            }
+
+            Optional<TensorInfo> hiddenStateOut;
+            Optional<TensorInfo> cellStateOut;
+
+            result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
+                                                                              outputStateIn,
+                                                                              cellStateIn,
+                                                                              output,
+                                                                              hiddenStateOut,
+                                                                              cellStateOut,
+                                                                              descriptor,
+                                                                              paramsInfo,
+                                                                              reason);
+            break;
+        }
         default:
         {
             ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
@@ -1759,4 +1900,11 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
     return std::unique_ptr<IWorkload>();
 }
 
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
+    const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
+    const WorkloadInfo& /*info*/) const
+{
+    return std::unique_ptr<IWorkload>();
+}
+
 } // namespace armnn
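The base-class CreateUnidirectionalSequenceLstm above intentionally returns an empty pointer, so any backend that does not override it simply reports the layer as unsupported. A sketch of the override a backend would add to claim the workload; the factory and workload class names here are hypothetical, since this patch adds no backend implementation:

    // Hypothetical backend factory; only the new entry point is shown and the
    // remaining virtual members of IWorkloadFactory are elided.
    class MyBackendWorkloadFactory : public armnn::IWorkloadFactory
    {
    public:
        std::unique_ptr<armnn::IWorkload> CreateUnidirectionalSequenceLstm(
            const armnn::UnidirectionalSequenceLstmQueueDescriptor& descriptor,
            const armnn::WorkloadInfo& info) const override
        {
            descriptor.Validate(info);  // reuse the checks added in WorkloadData.cpp
            return std::make_unique<MyUnidirectionalSequenceLstmWorkload>(descriptor, info);
        }
    };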
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 1987b9b664..efb8d99fa0 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -289,6 +289,10 @@ public:
         const TransposeConvolution2dQueueDescriptor& descriptor,
         const WorkloadInfo& info) const;
 
+    virtual std::unique_ptr<IWorkload> CreateUnidirectionalSequenceLstm(
+        const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
+        const WorkloadInfo& info) const;
+
 private:
     static bool IsLayerConfigurationSupported(const BackendId& backendId,
                                               const IConnectableLayer& connectableLayer,
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index ddd6eacb6d..21b33d297b 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -342,6 +342,56 @@ struct DummyLayer
 {
 };
 
+template<typename UnidirectionalSequenceLstmLayerType>
+struct DummyUnidirectionalSequenceLstmLayer
+{
+    DummyUnidirectionalSequenceLstmLayer()
+    {
+        typename UnidirectionalSequenceLstmLayerType::DescriptorType desc;
+        desc.m_CifgEnabled = false;
+
+        m_Layer = dummyGraph.AddLayer<UnidirectionalSequenceLstmLayerType>(desc, "");
+        m_Layer->m_BasicParameters.m_InputToForgetWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_InputToCellWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_InputToOutputWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_RecurrentToForgetWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_RecurrentToCellWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_RecurrentToOutputWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_ForgetGateBias = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_CellBias = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_OutputGateBias = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+
+        m_Layer->m_CifgParameters.m_InputToInputWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_CifgParameters.m_RecurrentToInputWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_CifgParameters.m_InputGateBias = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+    }
+
+    ~DummyUnidirectionalSequenceLstmLayer()
+    {
+        dummyGraph.EraseLayer(m_Layer);
+    }
+
+    armnn::UnidirectionalSequenceLstmLayer* m_Layer;
+};
+
+template<>
+struct DummyLayer<armnn::UnidirectionalSequenceLstmLayer>
+    : public DummyUnidirectionalSequenceLstmLayer<armnn::UnidirectionalSequenceLstmLayer>
+{
+};
+
 template<>
 struct DummyLayer
 {
@@ -651,6 +701,7 @@ DECLARE_LAYER_POLICY_2_PARAM(Pooling2d)
 
 DECLARE_LAYER_POLICY_2_PARAM(PreCompiled)
 
 DECLARE_LAYER_POLICY_1_PARAM(Prelu)
+
 DECLARE_LAYER_POLICY_2_PARAM(QLstm)
 
 DECLARE_LAYER_POLICY_1_PARAM(QuantizedLstm)
@@ -691,6 +742,8 @@ DECLARE_LAYER_POLICY_2_PARAM(Transpose)
 
 DECLARE_LAYER_POLICY_2_PARAM(TransposeConvolution2d)
 
+DECLARE_LAYER_POLICY_2_PARAM(UnidirectionalSequenceLstm)
+
 DECLARE_LAYER_POLICY_MAP_PARAM(Unmap, void)
-- 
cgit v1.2.1