aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp13
1 files changed, 7 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 40ab798ba2..5628c36884 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -529,12 +529,6 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
optRecurrentToInputWeights =
OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
- if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
- {
- optCellToInputWeights =
- OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
- paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
- }
optInputGateBias =
OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
paramsInfo.m_InputGateBias = &optInputGateBias;
@@ -555,6 +549,13 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
if(descriptor.m_PeepholeEnabled)
{
+ if(!descriptor.m_CifgEnabled)
+ {
+ optCellToInputWeights =
+ OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
+ dataType);
+ paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
+ }
optCellToForgetWeights =
OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;