aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2020-03-30 15:07:45 +0100
committerJan Eilers <jan.eilers@arm.com>2020-03-31 08:46:25 +0100
commite2062cdf1eb31b87860f9889f0e799e89f0dfa30 (patch)
tree98b1cdf21856042aa24689c6385d78a1647eb2bf /src/backends/backendsCommon/WorkloadFactory.cpp
parentcedd34fa77a42fce6b832f6424eed45543fe71d4 (diff)
downloadarmnn-e2062cdf1eb31b87860f9889f0e799e89f0dfa30.tar.gz
IVGCVSW-4590 Fix Lstm layers CellToInputWeights
* CellToInputWeights were not handeled correctly * Changed CellToInputWeights from Cifg to peephole parameter * Modified exiting unit tests * Added unit test to cover new configuration * Added more descriptive error messages Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: Ied5dc1253d3df1fd1a79b887a58603d0a9c8f396
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp13
1 files changed, 7 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 40ab798ba2..5628c36884 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -529,12 +529,6 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
optRecurrentToInputWeights =
OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
- if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
- {
- optCellToInputWeights =
- OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
- paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
- }
optInputGateBias =
OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
paramsInfo.m_InputGateBias = &optInputGateBias;
@@ -555,6 +549,13 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
if(descriptor.m_PeepholeEnabled)
{
+ if(!descriptor.m_CifgEnabled)
+ {
+ optCellToInputWeights =
+ OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
+ dataType);
+ paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
+ }
optCellToForgetWeights =
OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;