aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-06-26 13:10:09 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-02 09:59:37 +0000
commit38e05bd2836b1b65b440330a9c283038ba4192c3 (patch)
treec232f71ce6a101c70ed65e046678f7b22593dbe4 /include
parentd0c0cc3e27f1ada9df167d3b9ff248be432d16e1 (diff)
downloadarmnn-38e05bd2836b1b65b440330a9c283038ba4192c3.tar.gz
IVGCVSW-3236 Extend Ref LSTM with layer normalization support
* Add descriptor values * Update lstm queue descriptor validate function * Update lstm workload * Update isLstmSupported (Cl and Ref), LayerSupportBase, ILayerSupport * Update lstm layer * Add unit tests Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'include')
-rw-r--r--include/armnn/Descriptors.hpp3
-rw-r--r--include/armnn/ILayerSupport.hpp6
-rw-r--r--include/armnn/LstmParams.hpp8
3 files changed, 16 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 85e8b56fed..9175239aa8 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -589,6 +589,7 @@ struct LstmDescriptor
, m_CifgEnabled(true)
, m_PeepholeEnabled(false)
, m_ProjectionEnabled(false)
+ , m_LayerNormEnabled(false)
{}
/// @brief The activation function to use.
@@ -604,6 +605,8 @@ struct LstmDescriptor
bool m_PeepholeEnabled;
/// Enable/disable the projection layer.
bool m_ProjectionEnabled;
+ /// Enable/disable layer normalization
+ bool m_LayerNormEnabled;
};
/// A MeanDescriptor for the MeanLayer.
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index bf0ac90c59..635b9cc663 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -170,7 +170,11 @@ public:
const TensorInfo* projectionBias,
const TensorInfo* cellToForgetWeights,
const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
+ const TensorInfo* inputLayerNormWeights = nullptr,
+ const TensorInfo* forgetLayerNormWeights = nullptr,
+ const TensorInfo* cellLayerNormWeights = nullptr,
+ const TensorInfo* outputLayerNormWeights = nullptr) const = 0;
virtual bool IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& input1,
diff --git a/include/armnn/LstmParams.hpp b/include/armnn/LstmParams.hpp
index c4f38f0067..a7c57c78b2 100644
--- a/include/armnn/LstmParams.hpp
+++ b/include/armnn/LstmParams.hpp
@@ -29,6 +29,10 @@ struct LstmInputParams
, m_OutputGateBias(nullptr)
, m_ProjectionWeights(nullptr)
, m_ProjectionBias(nullptr)
+ , m_InputLayerNormWeights(nullptr)
+ , m_ForgetLayerNormWeights(nullptr)
+ , m_CellLayerNormWeights(nullptr)
+ , m_OutputLayerNormWeights(nullptr)
{
}
@@ -49,6 +53,10 @@ struct LstmInputParams
const ConstTensor* m_OutputGateBias;
const ConstTensor* m_ProjectionWeights;
const ConstTensor* m_ProjectionBias;
+ const ConstTensor* m_InputLayerNormWeights;
+ const ConstTensor* m_ForgetLayerNormWeights;
+ const ConstTensor* m_CellLayerNormWeights;
+ const ConstTensor* m_OutputLayerNormWeights;
};
} // namespace armnn