aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/LstmLayer.cpp
diff options
context:
space:
mode:
authorjimfly01 <jim.flynn@arm.com>2019-01-28 12:51:53 +0000
committerjimfly01 <jim.flynn@arm.com>2019-01-29 10:37:47 +0000
commitd161ba0bc83fa14f7aea4c629ca3e6ea04a2dc34 (patch)
tree909d956ede3aaaf2812d4141a4742c4e2c936122 /src/armnn/layers/LstmLayer.cpp
parentc6a41ffa25d468a69465e1a7b22b280b029f65a2 (diff)
downloadarmnn-d161ba0bc83fa14f7aea4c629ca3e6ea04a2dc34.tar.gz
IVGCVSW-2569 Add implementation of ConstTensor Accept functions
* Create the required ConstTensors and pass them to the appropriate visit method. Back fill of dummies added during IVGCVSW-2547 * Moved the VisitDetectionPostProcessLayer function declaration in ILayerVistor to its correct location after the VisitDepthwiseConvolution2dLayer functions. Change-Id: I0bd2f8c3603cbdb933b1216ead96dd8273eb5013
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r--src/armnn/layers/LstmLayer.cpp111
1 files changed, 109 insertions, 2 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp
index 942038a315..06140c924f 100644
--- a/src/armnn/layers/LstmLayer.cpp
+++ b/src/armnn/layers/LstmLayer.cpp
@@ -251,8 +251,115 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
void LstmLayer::Accept(ILayerVisitor& visitor) const
{
- LstmInputParams dummy;
- visitor.VisitLstmLayer(this, GetParameters(), dummy, GetName());
+ LstmInputParams inputParams;
+ if (m_CifgParameters.m_InputToInputWeights != nullptr)
+ {
+ ConstTensor inputToInputWeightsTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
+ m_CifgParameters.m_InputToInputWeights->GetConstTensor<void*>());
+ inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
+ }
+ if (m_BasicParameters.m_InputToForgetWeights != nullptr)
+ {
+ ConstTensor inputToForgetWeightsTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToForgetWeights->GetConstTensor<void*>());
+ inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
+ }
+ if (m_BasicParameters.m_InputToCellWeights != nullptr)
+ {
+ ConstTensor inputToCellWeightsTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToCellWeights->GetConstTensor<void*>());
+ inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
+ }
+ if (m_BasicParameters.m_InputToOutputWeights != nullptr)
+ {
+ ConstTensor inputToOutputWeightsTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToOutputWeights->GetConstTensor<void*>());
+ inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
+ }
+ if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
+ {
+ ConstTensor recurrentToInputWeightsTensor(
+ m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(),
+ m_CifgParameters.m_RecurrentToInputWeights->GetConstTensor<void*>());
+ inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
+ }
+ if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
+ {
+ ConstTensor recurrentToForgetWeightsTensor(
+ m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
+ m_BasicParameters.m_RecurrentToForgetWeights->GetConstTensor<void*>());
+ inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
+ }
+ if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
+ {
+ ConstTensor recurrentToCellWeightsTensor(
+ m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(),
+ m_BasicParameters.m_RecurrentToCellWeights->GetConstTensor<void*>());
+ inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
+ }
+ if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
+ {
+ ConstTensor recurrentToOutputWeightsTensor(
+ m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
+ m_BasicParameters.m_RecurrentToOutputWeights->GetConstTensor<void*>());
+ inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
+ }
+ if (m_CifgParameters.m_CellToInputWeights != nullptr)
+ {
+ ConstTensor cellToInputWeightsTensor(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(),
+ m_CifgParameters.m_CellToInputWeights->GetConstTensor<void*>());
+ inputParams.m_CellToInputWeights = &cellToInputWeightsTensor;
+ }
+ if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
+ {
+ ConstTensor cellToForgetWeightsTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
+ m_PeepholeParameters.m_CellToForgetWeights->GetConstTensor<void*>());
+ inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor;
+ }
+ if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
+ {
+ ConstTensor cellToOutputWeightsTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
+ m_PeepholeParameters.m_CellToOutputWeights->GetConstTensor<void*>());
+ inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor;
+ }
+ if (m_CifgParameters.m_InputGateBias != nullptr)
+ {
+ ConstTensor inputGateBiasTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
+ m_CifgParameters.m_InputGateBias->GetConstTensor<void*>());
+ inputParams.m_InputGateBias = &inputGateBiasTensor;
+ }
+ if (m_BasicParameters.m_ForgetGateBias != nullptr)
+ {
+ ConstTensor forgetGateBiasTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
+ m_BasicParameters.m_ForgetGateBias->GetConstTensor<void*>());
+ inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
+ }
+ if (m_BasicParameters.m_CellBias != nullptr)
+ {
+ ConstTensor cellBiasTensor(m_BasicParameters.m_CellBias->GetTensorInfo(),
+ m_BasicParameters.m_CellBias->GetConstTensor<void*>());
+ inputParams.m_CellBias = &cellBiasTensor;
+ }
+ if (m_BasicParameters.m_OutputGateBias != nullptr)
+ {
+ ConstTensor outputGateBias(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
+ m_BasicParameters.m_OutputGateBias->GetConstTensor<void*>());
+ inputParams.m_OutputGateBias = &outputGateBias;
+ }
+ if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
+ {
+ ConstTensor projectionWeightsTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
+ m_ProjectionParameters.m_ProjectionWeights->GetConstTensor<void*>());
+ inputParams.m_ProjectionWeights = &projectionWeightsTensor;
+ }
+ if (m_ProjectionParameters.m_ProjectionBias != nullptr)
+ {
+ ConstTensor projectionBiasTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
+ m_ProjectionParameters.m_ProjectionBias->GetConstTensor<void*>());
+ inputParams.m_ProjectionBias = &projectionBiasTensor;
+ }
+
+ visitor.VisitLstmLayer(this, GetParameters(), inputParams, GetName());
}
} // namespace armnn