aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/LstmLayer.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2020-03-30 15:07:45 +0100
committerJan Eilers <jan.eilers@arm.com>2020-03-31 08:46:25 +0100
commite2062cdf1eb31b87860f9889f0e799e89f0dfa30 (patch)
tree98b1cdf21856042aa24689c6385d78a1647eb2bf /src/armnn/layers/LstmLayer.cpp
parentcedd34fa77a42fce6b832f6424eed45543fe71d4 (diff)
downloadarmnn-e2062cdf1eb31b87860f9889f0e799e89f0dfa30.tar.gz
IVGCVSW-4590 Fix Lstm layers CellToInputWeights
* CellToInputWeights were not handeled correctly * Changed CellToInputWeights from Cifg to peephole parameter * Modified exiting unit tests * Added unit test to cover new configuration * Added more descriptive error messages Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: Ied5dc1253d3df1fd1a79b887a58603d0a9c8f396
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r--src/armnn/layers/LstmLayer.cpp28
1 files changed, 19 insertions, 9 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp
index 581ba45c5f..1d945690d5 100644
--- a/src/armnn/layers/LstmLayer.cpp
+++ b/src/armnn/layers/LstmLayer.cpp
@@ -39,7 +39,6 @@ std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const IWorkloadFactory& fac
{
descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get();
descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
- descriptor.m_CellToInputWeights = m_CifgParameters.m_CellToInputWeights.get();
descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get();
}
@@ -53,6 +52,10 @@ std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const IWorkloadFactory& fac
// Peephole parameters
if (m_Param.m_PeepholeEnabled)
{
+ if (!m_Param.m_CifgEnabled)
+ {
+ descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
+ }
descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
}
@@ -102,8 +105,6 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const
std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_InputToInputWeights) : nullptr;
layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_RecurrentToInputWeights) : nullptr;
- layer->m_CifgParameters.m_CellToInputWeights = m_CifgParameters.m_CellToInputWeights ?
- std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_CellToInputWeights) : nullptr;
layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_InputGateBias) : nullptr;
}
@@ -118,6 +119,11 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const
if (m_Param.m_PeepholeEnabled)
{
+ if (!m_Param.m_CifgEnabled)
+ {
+ layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToInputWeights) : nullptr;
+ }
layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToForgetWeights) : nullptr;
layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
@@ -209,8 +215,6 @@ void LstmLayer::ValidateTensorShapesFromInputs()
"LstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled.");
BOOST_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
"LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value when CIFG is enabled.");
- BOOST_ASSERT_MSG(m_CifgParameters.m_CellToInputWeights == nullptr,
- "LstmLayer: m_CifgParameters.m_CellToInputWeights should not have a value when CIFG is enabled.");
BOOST_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
"LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled.");
@@ -228,6 +232,12 @@ void LstmLayer::ValidateTensorShapesFromInputs()
if (m_Param.m_PeepholeEnabled)
{
+ if (!m_Param.m_CifgEnabled)
+ {
+ BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
+ "LstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null "
+ "when Peephole is enabled and CIFG is disabled.");
+ }
BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
"LstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null.");
BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
@@ -278,7 +288,6 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
// Cifg parameters
m_CifgParameters.m_InputToInputWeights,
m_CifgParameters.m_RecurrentToInputWeights,
- m_CifgParameters.m_CellToInputWeights,
m_CifgParameters.m_InputGateBias,
// Projection parameters
@@ -286,6 +295,7 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
m_ProjectionParameters.m_ProjectionBias,
// Peephole parameters
+ m_PeepholeParameters.m_CellToInputWeights,
m_PeepholeParameters.m_CellToForgetWeights,
m_PeepholeParameters.m_CellToOutputWeights,
@@ -368,10 +378,10 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const
inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
}
ConstTensor cellToInputWeightsTensor;
- if (m_CifgParameters.m_CellToInputWeights != nullptr)
+ if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
{
- ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_CellToInputWeights->Map(true));
+ ConstTensor cellToInputWeightsTensorCopy(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
+ m_PeepholeParameters.m_CellToInputWeights->Map(true));
cellToInputWeightsTensor = cellToInputWeightsTensorCopy;
inputParams.m_CellToInputWeights = &cellToInputWeightsTensor;
}