From e30c16e7759d141c7f262988b67a7ec13758e596 Mon Sep 17 00:00:00 2001 From: Ferran Balaguer Date: Wed, 24 Jul 2019 17:03:45 +0100 Subject: IVGCVSW-3566 Fix LSTM with normalization and Cifg WorkloadFactory Signed-off-by: Ferran Balaguer Change-Id: I586415357d0f0d315c7174ad385167effa66b195 --- src/backends/backendsCommon/WorkloadFactory.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index cbaae4075c..dca5778e0e 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -460,9 +460,12 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, if(descriptor.m_LayerNormEnabled) { - optInputLayerNormWeights = OverrideDataType( - cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); - paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; + if (!descriptor.m_CifgEnabled) + { + optInputLayerNormWeights = OverrideDataType( + cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); + paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; + } optForgetLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType); -- cgit v1.2.1