diff options
author | Ferran Balaguer <ferran.balaguer@arm.com> | 2019-07-24 17:03:45 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-07-26 12:35:14 +0000 |
commit | e30c16e7759d141c7f262988b67a7ec13758e596 (patch) | |
tree | f65fda41c6eb863d8f66818c9b9d7e4c184acbc8 | |
parent | 23700bb441c064dddc5327be37dcf88b541bf652 (diff) | |
download | armnn-e30c16e7759d141c7f262988b67a7ec13758e596.tar.gz |
IVGCVSW-3566 Fix LSTM with normalization and Cifg WorkloadFactory
Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com>
Change-Id: I586415357d0f0d315c7174ad385167effa66b195
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index cbaae4075c..dca5778e0e 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -460,9 +460,12 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, if(descriptor.m_LayerNormEnabled) { - optInputLayerNormWeights = OverrideDataType( - cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); - paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; + if (!descriptor.m_CifgEnabled) + { + optInputLayerNormWeights = OverrideDataType( + cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); + paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; + } optForgetLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType); |