From e2062cdf1eb31b87860f9889f0e799e89f0dfa30 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Mon, 30 Mar 2020 15:07:45 +0100 Subject: IVGCVSW-4590 Fix Lstm layers CellToInputWeights * CellToInputWeights were not handeled correctly * Changed CellToInputWeights from Cifg to peephole parameter * Modified exiting unit tests * Added unit test to cover new configuration * Added more descriptive error messages Signed-off-by: Jan Eilers Change-Id: Ied5dc1253d3df1fd1a79b887a58603d0a9c8f396 --- src/armnn/layers/LstmLayer.cpp | 28 +++++++++++++++++++--------- src/armnn/layers/LstmLayer.hpp | 4 ++-- 2 files changed, 21 insertions(+), 11 deletions(-) (limited to 'src/armnn/layers') diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 581ba45c5f..1d945690d5 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -39,7 +39,6 @@ std::unique_ptr LstmLayer::CreateWorkload(const IWorkloadFactory& fac { descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get(); descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get(); - descriptor.m_CellToInputWeights = m_CifgParameters.m_CellToInputWeights.get(); descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get(); } @@ -53,6 +52,10 @@ std::unique_ptr LstmLayer::CreateWorkload(const IWorkloadFactory& fac // Peephole parameters if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get(); + } descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get(); descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get(); } @@ -102,8 +105,6 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const std::make_unique(*m_CifgParameters.m_InputToInputWeights) : nullptr; layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ? std::make_unique(*m_CifgParameters.m_RecurrentToInputWeights) : nullptr; - layer->m_CifgParameters.m_CellToInputWeights = m_CifgParameters.m_CellToInputWeights ? - std::make_unique(*m_CifgParameters.m_CellToInputWeights) : nullptr; layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ? std::make_unique(*m_CifgParameters.m_InputGateBias) : nullptr; } @@ -118,6 +119,11 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ? + std::make_unique(*m_PeepholeParameters.m_CellToInputWeights) : nullptr; + } layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ? std::make_unique(*m_PeepholeParameters.m_CellToForgetWeights) : nullptr; layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ? @@ -209,8 +215,6 @@ void LstmLayer::ValidateTensorShapesFromInputs() "LstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled."); BOOST_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr, "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value when CIFG is enabled."); - BOOST_ASSERT_MSG(m_CifgParameters.m_CellToInputWeights == nullptr, - "LstmLayer: m_CifgParameters.m_CellToInputWeights should not have a value when CIFG is enabled."); BOOST_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr, "LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled."); @@ -228,6 +232,12 @@ void LstmLayer::ValidateTensorShapesFromInputs() if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr, + "LstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null " + "when Peephole is enabled and CIFG is disabled."); + } BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr, "LstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null."); BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr, @@ -278,7 +288,6 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() // Cifg parameters m_CifgParameters.m_InputToInputWeights, m_CifgParameters.m_RecurrentToInputWeights, - m_CifgParameters.m_CellToInputWeights, m_CifgParameters.m_InputGateBias, // Projection parameters @@ -286,6 +295,7 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() m_ProjectionParameters.m_ProjectionBias, // Peephole parameters + m_PeepholeParameters.m_CellToInputWeights, m_PeepholeParameters.m_CellToForgetWeights, m_PeepholeParameters.m_CellToOutputWeights, @@ -368,10 +378,10 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; } ConstTensor cellToInputWeightsTensor; - if (m_CifgParameters.m_CellToInputWeights != nullptr) + if (m_PeepholeParameters.m_CellToInputWeights != nullptr) { - ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), - m_CifgParameters.m_CellToInputWeights->Map(true)); + ConstTensor cellToInputWeightsTensorCopy(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToInputWeights->Map(true)); cellToInputWeightsTensor = cellToInputWeightsTensorCopy; inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; } diff --git a/src/armnn/layers/LstmLayer.hpp b/src/armnn/layers/LstmLayer.hpp index 21421f220f..5ccb4bcf92 100644 --- a/src/armnn/layers/LstmLayer.hpp +++ b/src/armnn/layers/LstmLayer.hpp @@ -30,8 +30,6 @@ struct LstmOptCifgParameters /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. std::unique_ptr m_RecurrentToInputWeights; /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::unique_ptr m_CellToInputWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. std::unique_ptr m_InputGateBias; }; @@ -45,6 +43,8 @@ struct LstmOptProjectionParameters struct LstmOptPeepholeParameters { + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::unique_ptr m_CellToInputWeights; /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. std::unique_ptr m_CellToForgetWeights; /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. -- cgit v1.2.1