diff options
Diffstat (limited to 'src/armnn/layers')
-rw-r--r-- | src/armnn/layers/LstmLayer.cpp | 28 | ||||
-rw-r--r-- | src/armnn/layers/LstmLayer.hpp | 4 |
2 files changed, 21 insertions, 11 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 581ba45c5f..1d945690d5 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -39,7 +39,6 @@ std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const IWorkloadFactory& fac { descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get(); descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get(); - descriptor.m_CellToInputWeights = m_CifgParameters.m_CellToInputWeights.get(); descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get(); } @@ -53,6 +52,10 @@ std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const IWorkloadFactory& fac // Peephole parameters if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get(); + } descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get(); descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get(); } @@ -102,8 +105,6 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_InputToInputWeights) : nullptr; layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ? std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_RecurrentToInputWeights) : nullptr; - layer->m_CifgParameters.m_CellToInputWeights = m_CifgParameters.m_CellToInputWeights ? - std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_CellToInputWeights) : nullptr; layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ? std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_InputGateBias) : nullptr; } @@ -118,6 +119,11 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ? + std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToInputWeights) : nullptr; + } layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ? std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToForgetWeights) : nullptr; layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ? @@ -209,8 +215,6 @@ void LstmLayer::ValidateTensorShapesFromInputs() "LstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled."); BOOST_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr, "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value when CIFG is enabled."); - BOOST_ASSERT_MSG(m_CifgParameters.m_CellToInputWeights == nullptr, - "LstmLayer: m_CifgParameters.m_CellToInputWeights should not have a value when CIFG is enabled."); BOOST_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr, "LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled."); @@ -228,6 +232,12 @@ void LstmLayer::ValidateTensorShapesFromInputs() if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr, + "LstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null " + "when Peephole is enabled and CIFG is disabled."); + } BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr, "LstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null."); BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr, @@ -278,7 +288,6 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() // Cifg parameters m_CifgParameters.m_InputToInputWeights, m_CifgParameters.m_RecurrentToInputWeights, - m_CifgParameters.m_CellToInputWeights, m_CifgParameters.m_InputGateBias, // Projection parameters @@ -286,6 +295,7 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() m_ProjectionParameters.m_ProjectionBias, // Peephole parameters + m_PeepholeParameters.m_CellToInputWeights, m_PeepholeParameters.m_CellToForgetWeights, m_PeepholeParameters.m_CellToOutputWeights, @@ -368,10 +378,10 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; } ConstTensor cellToInputWeightsTensor; - if (m_CifgParameters.m_CellToInputWeights != nullptr) + if (m_PeepholeParameters.m_CellToInputWeights != nullptr) { - ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), - m_CifgParameters.m_CellToInputWeights->Map(true)); + ConstTensor cellToInputWeightsTensorCopy(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToInputWeights->Map(true)); cellToInputWeightsTensor = cellToInputWeightsTensorCopy; inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; } diff --git a/src/armnn/layers/LstmLayer.hpp b/src/armnn/layers/LstmLayer.hpp index 21421f220f..5ccb4bcf92 100644 --- a/src/armnn/layers/LstmLayer.hpp +++ b/src/armnn/layers/LstmLayer.hpp @@ -30,8 +30,6 @@ struct LstmOptCifgParameters /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeights; /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBias; }; @@ -46,6 +44,8 @@ struct LstmOptProjectionParameters struct LstmOptPeepholeParameters { /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. std::unique_ptr<ScopedCpuTensorHandle> m_CellToForgetWeights; /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. std::unique_ptr<ScopedCpuTensorHandle> m_CellToOutputWeights; |