aboutsummaryrefslogtreecommitdiff
path: root/src/backends
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp13
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp2
2 files changed, 7 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 40ab798ba2..5628c36884 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -529,12 +529,6 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
optRecurrentToInputWeights =
OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
- if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
- {
- optCellToInputWeights =
- OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
- paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
- }
optInputGateBias =
OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
paramsInfo.m_InputGateBias = &optInputGateBias;
@@ -555,6 +549,13 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
if(descriptor.m_PeepholeEnabled)
{
+ if(!descriptor.m_CifgEnabled)
+ {
+ optCellToInputWeights =
+ OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
+ dataType);
+ paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
+ }
optCellToForgetWeights =
OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 7534c8a7c9..dccfd1e75b 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -288,8 +288,6 @@ struct DummyLstmLayer
armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
m_Layer->m_CifgParameters.m_RecurrentToInputWeights = std::make_unique<armnn::ScopedCpuTensorHandle>(
armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
- m_Layer->m_CifgParameters.m_CellToInputWeights = std::make_unique<armnn::ScopedCpuTensorHandle>(
- armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
m_Layer->m_CifgParameters.m_InputGateBias = std::make_unique<armnn::ScopedCpuTensorHandle>(
armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
}