diff options
author | jimfly01 <jim.flynn@arm.com> | 2019-01-28 12:51:53 +0000 |
---|---|---|
committer | jimfly01 <jim.flynn@arm.com> | 2019-01-29 10:37:47 +0000 |
commit | d161ba0bc83fa14f7aea4c629ca3e6ea04a2dc34 (patch) | |
tree | 909d956ede3aaaf2812d4141a4742c4e2c936122 /src/armnn/layers/LstmLayer.cpp | |
parent | c6a41ffa25d468a69465e1a7b22b280b029f65a2 (diff) | |
download | armnn-d161ba0bc83fa14f7aea4c629ca3e6ea04a2dc34.tar.gz |
IVGCVSW-2569 Add implementation of ConstTensor Accept functions
* Create the required ConstTensors and pass them to the appropriate
visit method. Back fill of dummies added during IVGCVSW-2547
* Moved the VisitDetectionPostProcessLayer function declaration in
ILayerVistor to its correct location after the
VisitDepthwiseConvolution2dLayer functions.
Change-Id: I0bd2f8c3603cbdb933b1216ead96dd8273eb5013
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/LstmLayer.cpp | 111 |
1 files changed, 109 insertions, 2 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 942038a315..06140c924f 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -251,8 +251,115 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() void LstmLayer::Accept(ILayerVisitor& visitor) const { - LstmInputParams dummy; - visitor.VisitLstmLayer(this, GetParameters(), dummy, GetName()); + LstmInputParams inputParams; + if (m_CifgParameters.m_InputToInputWeights != nullptr) + { + ConstTensor inputToInputWeightsTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), + m_CifgParameters.m_InputToInputWeights->GetConstTensor<void*>()); + inputParams.m_InputToInputWeights = &inputToInputWeightsTensor; + } + if (m_BasicParameters.m_InputToForgetWeights != nullptr) + { + ConstTensor inputToForgetWeightsTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), + m_BasicParameters.m_InputToForgetWeights->GetConstTensor<void*>()); + inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor; + } + if (m_BasicParameters.m_InputToCellWeights != nullptr) + { + ConstTensor inputToCellWeightsTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), + m_BasicParameters.m_InputToCellWeights->GetConstTensor<void*>()); + inputParams.m_InputToCellWeights = &inputToCellWeightsTensor; + } + if (m_BasicParameters.m_InputToOutputWeights != nullptr) + { + ConstTensor inputToOutputWeightsTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), + m_BasicParameters.m_InputToOutputWeights->GetConstTensor<void*>()); + inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor; + } + if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) + { + ConstTensor recurrentToInputWeightsTensor( + m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), + m_CifgParameters.m_RecurrentToInputWeights->GetConstTensor<void*>()); + inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; + } + if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) + { + ConstTensor recurrentToForgetWeightsTensor( + m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToForgetWeights->GetConstTensor<void*>()); + inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; + } + if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) + { + ConstTensor recurrentToCellWeightsTensor( + m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToCellWeights->GetConstTensor<void*>()); + inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + } + if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) + { + ConstTensor recurrentToOutputWeightsTensor( + m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToOutputWeights->GetConstTensor<void*>()); + inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + } + if (m_CifgParameters.m_CellToInputWeights != nullptr) + { + ConstTensor cellToInputWeightsTensor(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), + m_CifgParameters.m_CellToInputWeights->GetConstTensor<void*>()); + inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; + } + if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) + { + ConstTensor cellToForgetWeightsTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToForgetWeights->GetConstTensor<void*>()); + inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor; + } + if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) + { + ConstTensor cellToOutputWeightsTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToOutputWeights->GetConstTensor<void*>()); + inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor; + } + if (m_CifgParameters.m_InputGateBias != nullptr) + { + ConstTensor inputGateBiasTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(), + m_CifgParameters.m_InputGateBias->GetConstTensor<void*>()); + inputParams.m_InputGateBias = &inputGateBiasTensor; + } + if (m_BasicParameters.m_ForgetGateBias != nullptr) + { + ConstTensor forgetGateBiasTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), + m_BasicParameters.m_ForgetGateBias->GetConstTensor<void*>()); + inputParams.m_ForgetGateBias = &forgetGateBiasTensor; + } + if (m_BasicParameters.m_CellBias != nullptr) + { + ConstTensor cellBiasTensor(m_BasicParameters.m_CellBias->GetTensorInfo(), + m_BasicParameters.m_CellBias->GetConstTensor<void*>()); + inputParams.m_CellBias = &cellBiasTensor; + } + if (m_BasicParameters.m_OutputGateBias != nullptr) + { + ConstTensor outputGateBias(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), + m_BasicParameters.m_OutputGateBias->GetConstTensor<void*>()); + inputParams.m_OutputGateBias = &outputGateBias; + } + if (m_ProjectionParameters.m_ProjectionWeights != nullptr) + { + ConstTensor projectionWeightsTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionWeights->GetConstTensor<void*>()); + inputParams.m_ProjectionWeights = &projectionWeightsTensor; + } + if (m_ProjectionParameters.m_ProjectionBias != nullptr) + { + ConstTensor projectionBiasTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionBias->GetConstTensor<void*>()); + inputParams.m_ProjectionBias = &projectionBiasTensor; + } + + visitor.VisitLstmLayer(this, GetParameters(), inputParams, GetName()); } } // namespace armnn |