aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index cbaae4075c..dca5778e0e 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -460,9 +460,12 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
if(descriptor.m_LayerNormEnabled)
{
- optInputLayerNormWeights = OverrideDataType(
- cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
- paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
+ if (!descriptor.m_CifgEnabled)
+ {
+ optInputLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
+ }
optForgetLayerNormWeights = OverrideDataType(
cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);