diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-07-15 16:16:25 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-07-22 18:29:55 +0100 |
commit | 8ed39ae450a077c7e4d672b5f05ff1d68ee67aab (patch) | |
tree | 31a1cf006e50db54f3e7a605825c8e9e3f9d689e /src/armnn | |
parent | 15fcc7ed3163c9d4b1856955271854198c3c2696 (diff) | |
download | armnn-8ed39ae450a077c7e4d672b5f05ff1d68ee67aab.tar.gz |
MLCE-530 Add front end support for UnidirectionalSequenceLstm on ArmNN
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I57bcbdec3eb0155f41af0fe7d6abf9bac2ec86eb
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/BackendHelper.cpp | 21 | ||||
-rw-r--r-- | src/armnn/LayersFwd.hpp | 4 | ||||
-rw-r--r-- | src/armnn/Network.cpp | 152 | ||||
-rw-r--r-- | src/armnn/Network.hpp | 4 | ||||
-rw-r--r-- | src/armnn/layers/LstmLayer.hpp | 63 | ||||
-rw-r--r-- | src/armnn/layers/LstmParameters.hpp | 76 | ||||
-rw-r--r-- | src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp | 492 | ||||
-rw-r--r-- | src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp | 65 |
8 files changed, 813 insertions, 64 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp index a7bf419a7c..13bde0aafa 100644 --- a/src/armnn/BackendHelper.cpp +++ b/src/armnn/BackendHelper.cpp @@ -842,4 +842,25 @@ bool LayerSupportHandle::IsTransposeSupported(const TensorInfo& input, return m_LayerSupport->IsTransposeSupported(input, output, descriptor, reasonIfUnsupported.value()); } +bool LayerSupportHandle::IsUnidirectionalSequenceLstmSupported(const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& output, + const Optional<TensorInfo>& hiddenStateOutput, + const Optional<TensorInfo>& cellStateOutput, + const LstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) +{ + return m_LayerSupport->IsUnidirectionalSequenceLstmSupported(input, + outputStateIn, + cellStateIn, + output, + hiddenStateOutput, + cellStateOutput, + descriptor, + paramsInfo, + reasonIfUnsupported); +} + }
\ No newline at end of file diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp index cdbcaa7e90..e3ae23cf40 100644 --- a/src/armnn/LayersFwd.hpp +++ b/src/armnn/LayersFwd.hpp @@ -73,6 +73,7 @@ #include "layers/SwitchLayer.hpp" #include "layers/TransposeConvolution2dLayer.hpp" #include "layers/TransposeLayer.hpp" +#include "layers/UnidirectionalSequenceLstmLayer.hpp" #include "layers/UnmapLayer.hpp" namespace armnn @@ -107,6 +108,7 @@ DECLARE_LAYER(Addition) DECLARE_LAYER(ArgMinMax) DECLARE_LAYER(BatchNormalization) DECLARE_LAYER(BatchToSpaceNd) +DECLARE_LAYER(Cast) DECLARE_LAYER(Comparison) DECLARE_LAYER(Concat) DECLARE_LAYER(Constant) @@ -168,6 +170,6 @@ DECLARE_LAYER(Subtraction) DECLARE_LAYER(Switch) DECLARE_LAYER(Transpose) DECLARE_LAYER(TransposeConvolution2d) +DECLARE_LAYER(UnidirectionalSequenceLstm) DECLARE_LAYER(Unmap) -DECLARE_LAYER(Cast) } diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index d340f021e2..83eafe7993 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -518,6 +518,14 @@ IConnectableLayer* INetwork::AddLogicalBinaryLayer(const LogicalBinaryDescriptor return pNetworkImpl->AddLogicalBinaryLayer(descriptor, name); } +IConnectableLayer* INetwork::AddUnidirectionalSequenceLstmLayer( + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParams& params, + const char* name) +{ + return pNetworkImpl->AddUnidirectionalSequenceLstmLayer(descriptor, params, name); +} + void INetwork::Accept(ILayerVisitor& visitor) const { return pNetworkImpl->Accept(visitor); @@ -2603,11 +2611,153 @@ IConnectableLayer* NetworkImpl::AddQLstmLayer(const QLstmDescriptor& descriptor } IConnectableLayer* NetworkImpl::AddLogicalBinaryLayer(const LogicalBinaryDescriptor& logicalBinaryDescriptor, - const char* name) + const char* name) { return m_Graph->AddLayer<LogicalBinaryLayer>(logicalBinaryDescriptor, name); } +IConnectableLayer* NetworkImpl::AddUnidirectionalSequenceLstmLayer( + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParams& params, + const char* name) +{ + const auto layer = m_Graph->AddLayer<UnidirectionalSequenceLstmLayer>(descriptor, name); + + //Lstm Basic Parameters + layer->m_BasicParameters.m_InputToForgetWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputToForgetWeights)); + layer->m_BasicParameters.m_InputToCellWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputToCellWeights)); + layer->m_BasicParameters.m_InputToOutputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputToOutputWeights)); + layer->m_BasicParameters.m_RecurrentToForgetWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToForgetWeights)); + layer->m_BasicParameters.m_RecurrentToCellWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToCellWeights)); + layer->m_BasicParameters.m_RecurrentToOutputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToOutputWeights)); + layer->m_BasicParameters.m_ForgetGateBias = + std::make_shared<ScopedTensorHandle>(*(params.m_ForgetGateBias)); + layer->m_BasicParameters.m_CellBias = + std::make_shared<ScopedTensorHandle>(*(params.m_CellBias)); + layer->m_BasicParameters.m_OutputGateBias = + std::make_shared<ScopedTensorHandle>(*(params.m_OutputGateBias)); + + //Lstm Cifg parameters + if(!descriptor.m_CifgEnabled) + { + if(params.m_InputToInputWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input To Input Weights cannot be NULL " + "when CIFG is disabled."); + } + if(params.m_RecurrentToInputWeights == nullptr) + { + throw InvalidArgumentException( + "AddUnidirectionalSequenceLstmLayer: Recurrent To Input Weights cannot be NULL " + "when CIFG is disabled."); + } + if(params.m_InputGateBias == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input Gate Bias cannot be NULL " + "when CIFG is disabled."); + } + layer->m_CifgParameters.m_InputToInputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputToInputWeights)); + layer->m_CifgParameters.m_RecurrentToInputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToInputWeights)); + layer->m_CifgParameters.m_InputGateBias = + std::make_shared<ScopedTensorHandle>(*(params.m_InputGateBias)); + } + + //Lstm projection parameters + if(descriptor.m_ProjectionEnabled) + { + if(params.m_ProjectionWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Projection Weights cannot be NULL " + "when projection is enabled."); + } + layer->m_ProjectionParameters.m_ProjectionWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_ProjectionWeights)); + if(params.m_ProjectionBias != nullptr) + { + layer->m_ProjectionParameters.m_ProjectionBias = + std::make_shared<ScopedTensorHandle>(*(params.m_ProjectionBias)); + } + } + + //Lstm Peephole params + if(descriptor.m_PeepholeEnabled) + { + if(!descriptor.m_CifgEnabled) + { + if(params.m_CellToInputWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Input Weights " + "cannot be NULL when Peephole is enabled and CIFG disabled."); + } + + layer->m_PeepholeParameters.m_CellToInputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_CellToInputWeights)); + } + + if(params.m_CellToForgetWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Forget Weights cannot be NULL " + "when Peephole is enabled."); + } + if(params.m_CellToOutputWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Output Weights cannot be NULL " + "when Peephole is enabled."); + } + + layer->m_PeepholeParameters.m_CellToForgetWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_CellToForgetWeights)); + layer->m_PeepholeParameters.m_CellToOutputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_CellToOutputWeights)); + } + + //Lstm Layer Normalization params + if(descriptor.m_LayerNormEnabled) + { + if(!descriptor.m_CifgEnabled) + { + if(params.m_InputLayerNormWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input layer normalization weights " + "cannot be NULL when layer normalization is enabled and CIFG disabled."); + } + layer->m_LayerNormParameters.m_InputLayerNormWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputLayerNormWeights)); + } + + if(params.m_ForgetLayerNormWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Forget layer normalization weights " + "cannot be NULL when layer normalization is enabled."); + } + if(params.m_CellLayerNormWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell layer normalization weights " + "cannot be NULL when layer normalization is enabled."); + } + if(params.m_OutputLayerNormWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Output layer normalization weights " + "cannot be NULL when layer normalization is enabled."); + } + layer->m_LayerNormParameters.m_ForgetLayerNormWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_ForgetLayerNormWeights)); + layer->m_LayerNormParameters.m_CellLayerNormWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_CellLayerNormWeights)); + layer->m_LayerNormParameters.m_OutputLayerNormWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_OutputLayerNormWeights)); + } + return layer; +} + void NetworkImpl::Accept(ILayerVisitor& visitor) const { for (auto layer : GetGraph()) diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index 6f9be5635a..54c3497c90 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -269,6 +269,10 @@ public: IConnectableLayer* AddTransposeLayer(const TransposeDescriptor& transposeDescriptor, const char* name = nullptr); + IConnectableLayer* AddUnidirectionalSequenceLstmLayer(const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParams& params, + const char* name = nullptr); + void Accept(ILayerVisitor& visitor) const; void ExecuteStrategy(IStrategy& strategy) const; diff --git a/src/armnn/layers/LstmLayer.hpp b/src/armnn/layers/LstmLayer.hpp index f711ea7607..dc6d12a1d8 100644 --- a/src/armnn/layers/LstmLayer.hpp +++ b/src/armnn/layers/LstmLayer.hpp @@ -5,74 +5,13 @@ #pragma once #include "LayerWithParameters.hpp" +#include "LstmParameters.hpp" namespace armnn { class ScopedTensorHandle; -struct LstmOptLayerNormParameters -{ - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_InputLayerNormWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_ForgetLayerNormWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_CellLayerNormWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_OutputLayerNormWeights; -}; - -struct LstmOptCifgParameters -{ - /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_InputToInputWeights; - /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_RecurrentToInputWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_InputGateBias; -}; - -struct LstmOptProjectionParameters -{ - /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_ProjectionWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [output_size]. - std::shared_ptr<ConstTensorHandle> m_ProjectionBias; -}; - -struct LstmOptPeepholeParameters -{ - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_CellToInputWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_CellToForgetWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_CellToOutputWeights; -}; - -struct LstmBasicParameters -{ - /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_InputToForgetWeights; - /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_InputToCellWeights; - /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_InputToOutputWeights; - /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_RecurrentToForgetWeights; - /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_RecurrentToCellWeights; - /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units]. - std::shared_ptr<ConstTensorHandle> m_RecurrentToOutputWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_ForgetGateBias; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_CellBias; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::shared_ptr<ConstTensorHandle> m_OutputGateBias; -}; - /// This layer represents a LSTM operation. class LstmLayer : public LayerWithParameters<LstmDescriptor> { diff --git a/src/armnn/layers/LstmParameters.hpp b/src/armnn/layers/LstmParameters.hpp new file mode 100644 index 0000000000..3809ea875f --- /dev/null +++ b/src/armnn/layers/LstmParameters.hpp @@ -0,0 +1,76 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "LayerWithParameters.hpp" + +namespace armnn +{ + +class ScopedTensorHandle; + +struct LstmOptLayerNormParameters +{ + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_InputLayerNormWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_ForgetLayerNormWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_CellLayerNormWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_OutputLayerNormWeights; +}; + +struct LstmOptCifgParameters +{ + /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_InputToInputWeights; + /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_RecurrentToInputWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_InputGateBias; +}; + +struct LstmOptProjectionParameters +{ + /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_ProjectionWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [output_size]. + std::shared_ptr<ConstTensorHandle> m_ProjectionBias; +}; + +struct LstmOptPeepholeParameters +{ + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_CellToInputWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_CellToForgetWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_CellToOutputWeights; +}; + +struct LstmBasicParameters +{ + /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_InputToForgetWeights; + /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_InputToCellWeights; + /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_InputToOutputWeights; + /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_RecurrentToForgetWeights; + /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_RecurrentToCellWeights; + /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units]. + std::shared_ptr<ConstTensorHandle> m_RecurrentToOutputWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_ForgetGateBias; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_CellBias; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::shared_ptr<ConstTensorHandle> m_OutputGateBias; +}; + +} // namespace diff --git a/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp new file mode 100644 index 0000000000..45417069e4 --- /dev/null +++ b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp @@ -0,0 +1,492 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include "UnidirectionalSequenceLstmLayer.hpp" + +#include "LayerCloneBase.hpp" + +#include <armnn/LstmParams.hpp> +#include <armnn/TypesUtils.hpp> +#include <backendsCommon/TensorHandle.hpp> +#include <backendsCommon/WorkloadFactory.hpp> + +namespace armnn +{ + +UnidirectionalSequenceLstmLayer::UnidirectionalSequenceLstmLayer(const LstmDescriptor& param, const char* name) + : LayerWithParameters(3, 1, LayerType::UnidirectionalSequenceLstm, param, name) +{ +} + +std::unique_ptr<IWorkload> UnidirectionalSequenceLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const +{ + UnidirectionalSequenceLstmQueueDescriptor descriptor; + + // Basic parameters + descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get(); + descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get(); + descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get(); + descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get(); + descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get(); + descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get(); + descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get(); + descriptor.m_CellBias = m_BasicParameters.m_CellBias.get(); + descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get(); + + // Cifg parameters + if (!m_Param.m_CifgEnabled) + { + descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get(); + descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get(); + descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get(); + } + + // Projection parameters + if (m_Param.m_ProjectionEnabled) + { + descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get(); + descriptor.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias.get(); + } + + // Peephole parameters + if (m_Param.m_PeepholeEnabled) + { + if (!m_Param.m_CifgEnabled) + { + descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get(); + } + descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get(); + descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get(); + } + + // Layer normalisation parameters + if(m_Param.m_LayerNormEnabled) + { + if (!m_Param.m_CifgEnabled) + { + descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get(); + } + descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get(); + descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get(); + descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get(); + } + + SetAdditionalInfo(descriptor); + + return factory.CreateUnidirectionalSequenceLstm(descriptor, PrepInfoAndDesc(descriptor)); +} + +UnidirectionalSequenceLstmLayer* UnidirectionalSequenceLstmLayer::Clone(Graph& graph) const +{ + auto layer = CloneBase<UnidirectionalSequenceLstmLayer>(graph, m_Param, GetName()); + + layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ? + m_BasicParameters.m_InputToForgetWeights + : nullptr; + layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ? + m_BasicParameters.m_InputToCellWeights : nullptr; + layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ? + m_BasicParameters.m_InputToOutputWeights : nullptr; + layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ? + m_BasicParameters.m_RecurrentToForgetWeights : nullptr; + layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ? + m_BasicParameters.m_RecurrentToCellWeights : nullptr; + layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ? + m_BasicParameters.m_RecurrentToOutputWeights : nullptr; + layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ? + m_BasicParameters.m_ForgetGateBias : nullptr; + layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ? + m_BasicParameters.m_CellBias : nullptr; + layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ? + m_BasicParameters.m_OutputGateBias : nullptr; + + if (!m_Param.m_CifgEnabled) + { + layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ? + m_CifgParameters.m_InputToInputWeights : nullptr; + layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ? + m_CifgParameters.m_RecurrentToInputWeights : nullptr; + layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ? + m_CifgParameters.m_InputGateBias : nullptr; + } + + if (m_Param.m_ProjectionEnabled) + { + layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ? + m_ProjectionParameters.m_ProjectionWeights : nullptr; + layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ? + m_ProjectionParameters.m_ProjectionBias : nullptr; + } + + if (m_Param.m_PeepholeEnabled) + { + if (!m_Param.m_CifgEnabled) + { + layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ? + m_PeepholeParameters.m_CellToInputWeights : nullptr; + } + layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ? + m_PeepholeParameters.m_CellToForgetWeights : nullptr; + layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ? + m_PeepholeParameters.m_CellToOutputWeights : nullptr; + } + + if (m_Param.m_LayerNormEnabled) + { + layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ? + m_LayerNormParameters.m_InputLayerNormWeights : nullptr; + layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ? + m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr; + layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ? + m_LayerNormParameters.m_CellLayerNormWeights : nullptr; + layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ? + m_LayerNormParameters.m_OutputLayerNormWeights : nullptr; + } + + return std::move(layer); +} + +std::vector<TensorShape> UnidirectionalSequenceLstmLayer::InferOutputShapes( + const std::vector<TensorShape>& inputShapes) const +{ + ARMNN_ASSERT(inputShapes.size() == 3); + + // Get input values for validation + unsigned int outputSize = inputShapes[1][1]; + + std::vector<TensorShape> outShapes; + if (m_Param.m_TimeMajor) + { + outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize})); + } + else + { + outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize})); + } + return outShapes; +} + +void UnidirectionalSequenceLstmLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(3, CHECK_LOCATION()); + + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); + + auto inferredShapes = InferOutputShapes( { + GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape() + }); + + ARMNN_ASSERT(inferredShapes.size() == 1); + + // Check if the weights are nullptr + ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null."); + ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null."); + ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null."); + ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights " + "should not be null."); + ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null."); + ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights " + "should not be null."); + ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null."); + ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_CellBias should not be null."); + ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr, + "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_OutputGateBias should not be null."); + + if (!m_Param.m_CifgEnabled) + { + ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null."); + ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights " + "should not be null."); + ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr, + "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not be null."); + } + else + { + ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr, + "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value " + "when CIFG is enabled."); + ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr, + "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value " + "when CIFG is enabled."); + ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr, + "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not have a value " + "when CIFG is enabled."); + } + + if (m_Param.m_ProjectionEnabled) + { + ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_ProjectionParameters.m_ProjectionWeights " + "should not be null."); + } + + if (m_Param.m_PeepholeEnabled) + { + if (!m_Param.m_CifgEnabled) + { + ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToInputWeights " + "should not be null " + "when Peephole is enabled and CIFG is disabled."); + } + ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToForgetWeights " + "should not be null."); + ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToOutputWeights " + "should not be null."); + } + + if (m_Param.m_LayerNormEnabled) + { + if(!m_Param.m_CifgEnabled) + { + ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_inputLayerNormWeights " + "should not be null."); + } + ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights " + "should not be null."); + ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_cellLayerNormWeights " + "should not be null."); + ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr, + "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_outputLayerNormWeights " + "should not be null."); + } + + ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "UnidirectionalSequenceLstmLayer"); +} + +Layer::ConstantTensors UnidirectionalSequenceLstmLayer::GetConstantTensorsByRef() +{ + return {m_BasicParameters.m_InputToForgetWeights, + m_BasicParameters.m_InputToCellWeights, + m_BasicParameters.m_InputToOutputWeights, + m_BasicParameters.m_RecurrentToForgetWeights, + m_BasicParameters.m_RecurrentToCellWeights, + m_BasicParameters.m_RecurrentToOutputWeights, + m_BasicParameters.m_ForgetGateBias, + m_BasicParameters.m_CellBias, + m_BasicParameters.m_OutputGateBias, + + // Cifg parameters + m_CifgParameters.m_InputToInputWeights, + m_CifgParameters.m_RecurrentToInputWeights, + m_CifgParameters.m_InputGateBias, + + // Projection parameters + m_ProjectionParameters.m_ProjectionWeights, + m_ProjectionParameters.m_ProjectionBias, + + // Peephole parameters + m_PeepholeParameters.m_CellToInputWeights, + m_PeepholeParameters.m_CellToForgetWeights, + m_PeepholeParameters.m_CellToOutputWeights, + + // Layer normalisation parameters + m_LayerNormParameters.m_InputLayerNormWeights, + m_LayerNormParameters.m_ForgetLayerNormWeights, + m_LayerNormParameters.m_CellLayerNormWeights, + m_LayerNormParameters.m_OutputLayerNormWeights}; +} + +void UnidirectionalSequenceLstmLayer::Accept(ILayerVisitor& visitor) const +{ + IgnoreUnused(visitor); + throw armnn::Exception("UnidirectionalSequenceLstmLayer: VisitUnidirectionalSequenceLstmLayer is not implemented"); +} + +void UnidirectionalSequenceLstmLayer::ExecuteStrategy(IStrategy& strategy) const +{ + std::vector<ConstTensor> constTensors; + + LstmDescriptor descriptor = GetParameters(); + + ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights); + ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights); + ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights); + ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights); + ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights); + ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights); + ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias); + ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias); + ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias); + + // Cifg parameters + ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights); + ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights); + ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias); + + // Projection parameters + ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights); + ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias); + + // Peephole parameters + ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights); + ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights); + ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights); + + // Layer normalisation parameters + ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights); + ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights); + ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights); + ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights); + + // First add mandatory/basic parameters + if (m_BasicParameters.m_InputToForgetWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(), + managedInputToForgetWeights.Map())); + } + if (m_BasicParameters.m_InputToCellWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(), + managedInputToCellWeights.Map())); + } + if (m_BasicParameters.m_InputToOutputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(), + managedInputToOutputWeights.Map())); + } + if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) + { + constTensors.emplace_back(ConstTensor( + managedRecurrentToForgetWeights.GetTensorInfo(), + managedRecurrentToForgetWeights.Map())); + } + if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) + { + constTensors.emplace_back(ConstTensor( + managedRecurrentToCellWeights.GetTensorInfo(), + managedRecurrentToCellWeights.Map())); + } + if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor( + managedRecurrentToOutputWeights.GetTensorInfo(), + managedRecurrentToOutputWeights.Map())); + } + if (m_BasicParameters.m_ForgetGateBias != nullptr) + { + constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(), + managedForgetGateBias.Map())); + } + if (m_BasicParameters.m_CellBias != nullptr) + { + constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(), + managedCellBias.Map())); + } + if (m_BasicParameters.m_OutputGateBias != nullptr) + { + constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(), + managedOutputGateBias.Map())); + } + + // Add cifg parameters + if (!descriptor.m_CifgEnabled) + { + if (m_CifgParameters.m_InputToInputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(), + managedInputToInputWeights.Map())); + } + if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor( + managedRecurrentToInputWeights.GetTensorInfo(), + managedRecurrentToInputWeights.Map())); + } + if (m_CifgParameters.m_InputGateBias != nullptr) + { + constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(), + managedInputGateBias.Map())); + } + } + + // Add peephole parameters + if (descriptor.m_PeepholeEnabled) + { + if (!descriptor.m_CifgEnabled) + { + if (m_PeepholeParameters.m_CellToInputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(), + managedCellToInputWeights.Map())); + } + } + if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(), + managedCellToForgetWeights.Map())); + } + if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(), + managedCellToOutputWeights.Map())); + } + } + + // Add projection parameters + if (descriptor.m_ProjectionEnabled) + { + if (m_ProjectionParameters.m_ProjectionWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(), + managedProjectionWeights.Map())); + } + if (m_ProjectionParameters.m_ProjectionBias != nullptr) + { + constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(), + managedProjectionBias.Map())); + } + } + + // Add norm parameters + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(), + managedInputLayerNormWeights.Map())); + } + } + if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(), + managedForgetLayerNormWeights.Map())); + } + if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(), + managedCellLayerNormWeights.Map())); + } + if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(), + managedOutputLayerNormWeights.Map())); + } + } + + strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName()); +} + +} // namespace armnn diff --git a/src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp b/src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp new file mode 100644 index 0000000000..fb59f01ab6 --- /dev/null +++ b/src/armnn/layers/UnidirectionalSequenceLstmLayer.hpp @@ -0,0 +1,65 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "LayerWithParameters.hpp" +#include "LstmParameters.hpp" + +namespace armnn +{ + +class ScopedTensorHandle; + +/// This layer represents a LSTM operation. +class UnidirectionalSequenceLstmLayer : public LayerWithParameters<LstmDescriptor> +{ +public: + + LstmBasicParameters m_BasicParameters; + LstmOptCifgParameters m_CifgParameters; + LstmOptProjectionParameters m_ProjectionParameters; + LstmOptPeepholeParameters m_PeepholeParameters; + LstmOptLayerNormParameters m_LayerNormParameters; + + /// Makes a workload for the UnidirectionalSequence LSTM type. + /// @param [in] graph The graph where this layer can be found. + /// @param [in] factory The workload factory which will create the workload. + /// @return A pointer to the created workload, or nullptr if not created. + virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override; + + /// Creates a dynamically-allocated copy of this layer. + /// @param [in] graph The graph into which this layer is being cloned. + UnidirectionalSequenceLstmLayer* Clone(Graph& graph) const override; + + /// Check if the input tensor shape(s) + /// will lead to a valid configuration of @ref UnidirectionalSequenceLstmLayer. + /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validated. + void ValidateTensorShapesFromInputs() override; + + /// By default returns inputShapes if the number of inputs are equal to number of outputs, + /// otherwise infers the output shapes from given input shapes and layer properties. + /// @param [in] inputShapes The input shapes layer has. + /// @return A vector to the inferred output shape. + std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override; + + void Accept(ILayerVisitor& visitor) const override; + + void ExecuteStrategy(IStrategy& strategy) const override; + +protected: + /// Constructor to create a UnidirectionalSequenceLstmLayer. + /// @param [in] param LstmDescriptor to configure the lstm operation. + /// @param [in] name Optional name for the layer. + UnidirectionalSequenceLstmLayer(const LstmDescriptor& param, const char* name); + + /// Default destructor + ~UnidirectionalSequenceLstmLayer() = default; + + /// Retrieve the handles to the constant values stored by the layer. + /// @return A vector of the constant tensors stored by this layer. + Layer::ConstantTensors GetConstantTensorsByRef() override; +}; + +} // namespace |