aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-06-26 13:10:09 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-02 09:59:37 +0000
commit38e05bd2836b1b65b440330a9c283038ba4192c3 (patch)
treec232f71ce6a101c70ed65e046678f7b22593dbe4 /src/backends/reference/RefLayerSupport.cpp
parentd0c0cc3e27f1ada9df167d3b9ff248be432d16e1 (diff)
downloadarmnn-38e05bd2836b1b65b440330a9c283038ba4192c3.tar.gz
IVGCVSW-3236 Extend Ref LSTM with layer normalization support
* Add descriptor values * Update lstm queue descriptor validate function * Update lstm workload * Update isLstmSupported (Cl and Ref), LayerSupportBase, ILayerSupport * Update lstm layer * Add unit tests Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp10
1 files changed, 9 insertions, 1 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index b563badca5..3d260c5abd 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -861,7 +861,11 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
const TensorInfo* projectionBias,
const TensorInfo* cellToForgetWeights,
const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported) const
+ Optional<std::string&> reasonIfUnsupported,
+ const TensorInfo* inputLayerNormWeights,
+ const TensorInfo* forgetLayerNormWeights,
+ const TensorInfo* cellLayerNormWeights,
+ const TensorInfo* outputLayerNormWeights) const
{
ignore_unused(descriptor);
ignore_unused(inputToForgetWeights);
@@ -881,6 +885,10 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
ignore_unused(projectionBias);
ignore_unused(cellToForgetWeights);
ignore_unused(cellToOutputWeights);
+ ignore_unused(inputLayerNormWeights);
+ ignore_unused(forgetLayerNormWeights);
+ ignore_unused(cellLayerNormWeights);
+ ignore_unused(outputLayerNormWeights);
bool supported = true;