From d01a83c8de77c44a938a618918d17385da3baa88 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Wed, 3 Jul 2019 18:20:40 +0100 Subject: IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported !android-nn-driver:1461 Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22 Signed-off-by: Jan Eilers --- include/armnn/ILayerSupport.hpp | 25 +------ include/armnn/LayerSupport.hpp | 11 +-- include/armnn/LstmParams.hpp | 145 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 31 deletions(-) (limited to 'include') diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index 58722fe1a0..53dd29d87e 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -153,28 +154,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const = 0; + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const = 0; virtual bool IsMaximumSupported(const TensorInfo& input0, const TensorInfo& input1, diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp index 35336ed7da..65f9d089ba 100644 --- a/include/armnn/LayerSupport.hpp +++ b/include/armnn/LayerSupport.hpp @@ -9,6 +9,7 @@ #include #include #include +#include "LstmParams.hpp" namespace armnn { @@ -178,15 +179,7 @@ bool IsLstmSupported(const BackendId& backend, const TensorInfo& input, const Te const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, char* reasonIfUnsupported = nullptr, + const LstmInputParamsInfo& paramsInfo, char* reasonIfUnsupported = nullptr, size_t reasonIfUnsupportedMaxLength = 1024); /// Deprecated in favor of IBackend and ILayerSupport interfaces diff --git a/include/armnn/LstmParams.hpp b/include/armnn/LstmParams.hpp index a7c57c78b2..0c8e66dfde 100644 --- a/include/armnn/LstmParams.hpp +++ b/include/armnn/LstmParams.hpp @@ -5,6 +5,7 @@ #pragma once #include "TensorFwd.hpp" +#include "Exceptions.hpp" namespace armnn { @@ -59,5 +60,149 @@ struct LstmInputParams const ConstTensor* m_OutputLayerNormWeights; }; +struct LstmInputParamsInfo +{ + LstmInputParamsInfo() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + const TensorInfo* m_InputToInputWeights; + const TensorInfo* m_InputToForgetWeights; + const TensorInfo* m_InputToCellWeights; + const TensorInfo* m_InputToOutputWeights; + const TensorInfo* m_RecurrentToInputWeights; + const TensorInfo* m_RecurrentToForgetWeights; + const TensorInfo* m_RecurrentToCellWeights; + const TensorInfo* m_RecurrentToOutputWeights; + const TensorInfo* m_CellToInputWeights; + const TensorInfo* m_CellToForgetWeights; + const TensorInfo* m_CellToOutputWeights; + const TensorInfo* m_InputGateBias; + const TensorInfo* m_ForgetGateBias; + const TensorInfo* m_CellBias; + const TensorInfo* m_OutputGateBias; + const TensorInfo* m_ProjectionWeights; + const TensorInfo* m_ProjectionBias; + const TensorInfo* m_InputLayerNormWeights; + const TensorInfo* m_ForgetLayerNormWeights; + const TensorInfo* m_CellLayerNormWeights; + const TensorInfo* m_OutputLayerNormWeights; + + const TensorInfo& deref(const TensorInfo* tensorInfo) const + { + if (tensorInfo != nullptr) + { + const TensorInfo &temp = *tensorInfo; + return temp; + } + throw InvalidArgumentException("Can't dereference a null pointer"); + } + + const TensorInfo& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + const TensorInfo& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + const TensorInfo& get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + const TensorInfo& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + const TensorInfo& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + const TensorInfo& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + const TensorInfo& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + const TensorInfo& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + const TensorInfo& get_CellToInputWeights() const + { + return deref(m_CellToInputWeights); + } + const TensorInfo& get_CellToForgetWeights() const + { + return deref(m_CellToForgetWeights); + } + const TensorInfo& get_CellToOutputWeights() const + { + return deref(m_CellToOutputWeights); + } + const TensorInfo& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + const TensorInfo& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + const TensorInfo& get_CellBias() const + { + return deref(m_CellBias); + } + const TensorInfo& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } + const TensorInfo& get_ProjectionWeights() const + { + return deref(m_ProjectionWeights); + } + const TensorInfo& get_ProjectionBias() const + { + return deref(m_ProjectionBias); + } + const TensorInfo& get_InputLayerNormWeights() const + { + return deref(m_InputLayerNormWeights); + } + const TensorInfo& get_ForgetLayerNormWeights() const + { + return deref(m_ForgetLayerNormWeights); + } + const TensorInfo& get_CellLayerNormWeights() const + { + return deref(m_CellLayerNormWeights); + } + const TensorInfo& get_OutputLayerNormWeights() const + { + return deref(m_OutputLayerNormWeights); + } +}; + } // namespace armnn -- cgit v1.2.1