aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-06-26 13:10:09 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-02 09:59:37 +0000
commit38e05bd2836b1b65b440330a9c283038ba4192c3 (patch)
treec232f71ce6a101c70ed65e046678f7b22593dbe4 /src/backends/backendsCommon/WorkloadFactory.cpp
parentd0c0cc3e27f1ada9df167d3b9ff248be432d16e1 (diff)
downloadarmnn-38e05bd2836b1b65b440330a9c283038ba4192c3.tar.gz
IVGCVSW-3236 Extend Ref LSTM with layer normalization support
* Add descriptor values * Update lstm queue descriptor validate function * Update lstm workload * Update isLstmSupported (Cl and Ref), LayerSupportBase, ILayerSupport * Update lstm layer * Add unit tests Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp33
1 files changed, 32 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index b74b6afeb3..8ef5985fb3 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -396,6 +396,10 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
const TensorInfo* projectionBias = nullptr;
const TensorInfo* cellToForgetWeights = nullptr;
const TensorInfo* cellToOutputWeights = nullptr;
+ const TensorInfo* inputLayerNormWeights = nullptr;
+ const TensorInfo* forgetLayerNormWeights = nullptr;
+ const TensorInfo* cellLayerNormWeights = nullptr;
+ const TensorInfo* outputLayerNormWeights = nullptr;
TensorInfo optInputToInputWeights;
TensorInfo optRecurrentToInputWeights;
@@ -405,6 +409,10 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
TensorInfo optProjectionBias;
TensorInfo optCellToForgetWeights;
TensorInfo optCellToOutputWeights;
+ TensorInfo optInputLayerNormWeights;
+ TensorInfo optForgetLayerNormWeights;
+ TensorInfo optCellLayerNormWeights;
+ TensorInfo optOutputLayerNormWeights;
if(!descriptor.m_CifgEnabled)
{
@@ -449,6 +457,25 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
cellToOutputWeights = &optCellToOutputWeights;
}
+ if(descriptor.m_LayerNormEnabled)
+ {
+ optInputLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
+ inputLayerNormWeights = &optInputLayerNormWeights;
+
+ optForgetLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
+ forgetLayerNormWeights = &optForgetLayerNormWeights;
+
+ optCellLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
+ cellLayerNormWeights = &optCellLayerNormWeights;
+
+ optOutputLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
+ outputLayerNormWeights = &optOutputLayerNormWeights;
+ }
+
result = layerSupportObject->IsLstmSupported(
input,
outputStateIn,
@@ -475,7 +502,11 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
projectionBias,
cellToForgetWeights,
cellToOutputWeights,
- reason);
+ reason,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
break;
}
case LayerType::Maximum: