aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFerran Balaguer <ferran.balaguer@arm.com>2019-07-24 17:03:45 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-07-26 12:35:14 +0000
commite30c16e7759d141c7f262988b67a7ec13758e596 (patch)
treef65fda41c6eb863d8f66818c9b9d7e4c184acbc8
parent23700bb441c064dddc5327be37dcf88b541bf652 (diff)
downloadarmnn-e30c16e7759d141c7f262988b67a7ec13758e596.tar.gz
IVGCVSW-3566 Fix LSTM with normalization and Cifg WorkloadFactory
Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com> Change-Id: I586415357d0f0d315c7174ad385167effa66b195
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index cbaae4075c..dca5778e0e 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -460,9 +460,12 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
if(descriptor.m_LayerNormEnabled)
{
- optInputLayerNormWeights = OverrideDataType(
- cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
- paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
+ if (!descriptor.m_CifgEnabled)
+ {
+ optInputLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
+ }
optForgetLayerNormWeights = OverrideDataType(
cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);