diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index cbaae4075c..dca5778e0e 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -460,9 +460,12 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, if(descriptor.m_LayerNormEnabled) { - optInputLayerNormWeights = OverrideDataType( - cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); - paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; + if (!descriptor.m_CifgEnabled) + { + optInputLayerNormWeights = OverrideDataType( + cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); + paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; + } optForgetLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType); |