From 177fa0ba936eaf9de96f04bb91aa51d7656dd655 Mon Sep 17 00:00:00 2001 From: Ferran Balaguer Date: Tue, 2 Jul 2019 17:34:46 +0100 Subject: IVGCVSW-3396 Support joined lstm parameters !armnn:1470 Signed-off-by: Ferran Balaguer Change-Id: I67a393c1556f0b3022436e41f82f2bf1ab3a1d40 --- 1.0/HalPolicy.cpp | 63 +++++++++++++++++-------------------------------------- 1 file changed, 19 insertions(+), 44 deletions(-) (limited to '1.0') diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp index 13c93277..9673a74c 100644 --- a/1.0/HalPolicy.cpp +++ b/1.0/HalPolicy.cpp @@ -874,50 +874,41 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); // Basic parameters - const armnn::TensorInfo& inputToForgetWeights = params.m_InputToForgetWeights->GetInfo(); - const armnn::TensorInfo& inputToCellWeights = params.m_InputToCellWeights->GetInfo(); - const armnn::TensorInfo& inputToOutputWeights = params.m_InputToOutputWeights->GetInfo(); - const armnn::TensorInfo& recurrentToForgetWeights = params.m_RecurrentToForgetWeights->GetInfo(); - const armnn::TensorInfo& recurrentToCellWeights = params.m_RecurrentToCellWeights->GetInfo(); - const armnn::TensorInfo& recurrentToOutputWeights = params.m_RecurrentToOutputWeights->GetInfo(); - const armnn::TensorInfo& forgetGateBias = params.m_ForgetGateBias->GetInfo(); - const armnn::TensorInfo& cellBias = params.m_CellBias->GetInfo(); - const armnn::TensorInfo& outputGateBias = params.m_OutputGateBias->GetInfo(); - - //Optional parameters - const armnn::TensorInfo* inputToInputWeights = nullptr; - const armnn::TensorInfo* recurrentToInputWeights = nullptr; - const armnn::TensorInfo* cellToInputWeights = nullptr; - const armnn::TensorInfo* inputGateBias = nullptr; - const armnn::TensorInfo* projectionWeights = nullptr; - const armnn::TensorInfo* projectionBias = nullptr; - const armnn::TensorInfo* cellToForgetWeights = nullptr; - const armnn::TensorInfo* cellToOutputWeights = nullptr; + armnn::LstmInputParamsInfo paramsInfo; + paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo()); + paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo()); + paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo()); + paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo()); + paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo()); + paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo()); + paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo()); + paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo()); + paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo()); if(!desc.m_CifgEnabled) { - inputToInputWeights = &(params.m_InputToInputWeights->GetInfo()); - recurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo()); + paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo()); + paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo()); if (params.m_CellToInputWeights != nullptr) { - cellToInputWeights = &(params.m_CellToInputWeights->GetInfo()); + paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo()); } - inputGateBias = &(params.m_InputGateBias->GetInfo()); + paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo()); } if(desc.m_ProjectionEnabled) { - projectionWeights = &(params.m_ProjectionWeights->GetInfo()); + paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo()); if (params.m_ProjectionBias != nullptr) { - projectionBias = &(params.m_ProjectionBias->GetInfo()); + paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo()); } } if(desc.m_PeepholeEnabled) { - cellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo()); - cellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo()); + paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo()); + paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo()); } if (!IsLayerSupportedForAnyBackend(__func__, @@ -931,23 +922,7 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv cellStateOutInfo, outputInfo, desc, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights)) + paramsInfo)) { return false; } -- cgit v1.2.1