diff options
author | Ferran Balaguer <ferran.balaguer@arm.com> | 2019-07-02 17:34:46 +0100 |
---|---|---|
committer | Ferran Balaguer <ferran.balaguer@arm.com> | 2019-07-09 12:48:54 +0100 |
commit | 177fa0ba936eaf9de96f04bb91aa51d7656dd655 (patch) | |
tree | c2232336657d93a13f0359ed63a9f6555d4519d2 /1.0 | |
parent | 44381518586476ce7aef78b00bc6a905ddf5730a (diff) | |
download | android-nn-driver-177fa0ba936eaf9de96f04bb91aa51d7656dd655.tar.gz |
IVGCVSW-3396 Support joined lstm parameters
!armnn:1470
Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com>
Change-Id: I67a393c1556f0b3022436e41f82f2bf1ab3a1d40
Diffstat (limited to '1.0')
-rw-r--r-- | 1.0/HalPolicy.cpp | 63 |
1 files changed, 19 insertions, 44 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp index 13c93277..9673a74c 100644 --- a/1.0/HalPolicy.cpp +++ b/1.0/HalPolicy.cpp @@ -874,50 +874,41 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); // Basic parameters - const armnn::TensorInfo& inputToForgetWeights = params.m_InputToForgetWeights->GetInfo(); - const armnn::TensorInfo& inputToCellWeights = params.m_InputToCellWeights->GetInfo(); - const armnn::TensorInfo& inputToOutputWeights = params.m_InputToOutputWeights->GetInfo(); - const armnn::TensorInfo& recurrentToForgetWeights = params.m_RecurrentToForgetWeights->GetInfo(); - const armnn::TensorInfo& recurrentToCellWeights = params.m_RecurrentToCellWeights->GetInfo(); - const armnn::TensorInfo& recurrentToOutputWeights = params.m_RecurrentToOutputWeights->GetInfo(); - const armnn::TensorInfo& forgetGateBias = params.m_ForgetGateBias->GetInfo(); - const armnn::TensorInfo& cellBias = params.m_CellBias->GetInfo(); - const armnn::TensorInfo& outputGateBias = params.m_OutputGateBias->GetInfo(); - - //Optional parameters - const armnn::TensorInfo* inputToInputWeights = nullptr; - const armnn::TensorInfo* recurrentToInputWeights = nullptr; - const armnn::TensorInfo* cellToInputWeights = nullptr; - const armnn::TensorInfo* inputGateBias = nullptr; - const armnn::TensorInfo* projectionWeights = nullptr; - const armnn::TensorInfo* projectionBias = nullptr; - const armnn::TensorInfo* cellToForgetWeights = nullptr; - const armnn::TensorInfo* cellToOutputWeights = nullptr; + armnn::LstmInputParamsInfo paramsInfo; + paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo()); + paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo()); + paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo()); + paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo()); + paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo()); + paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo()); + paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo()); + paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo()); + paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo()); if(!desc.m_CifgEnabled) { - inputToInputWeights = &(params.m_InputToInputWeights->GetInfo()); - recurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo()); + paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo()); + paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo()); if (params.m_CellToInputWeights != nullptr) { - cellToInputWeights = &(params.m_CellToInputWeights->GetInfo()); + paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo()); } - inputGateBias = &(params.m_InputGateBias->GetInfo()); + paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo()); } if(desc.m_ProjectionEnabled) { - projectionWeights = &(params.m_ProjectionWeights->GetInfo()); + paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo()); if (params.m_ProjectionBias != nullptr) { - projectionBias = &(params.m_ProjectionBias->GetInfo()); + paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo()); } } if(desc.m_PeepholeEnabled) { - cellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo()); - cellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo()); + paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo()); + paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo()); } if (!IsLayerSupportedForAnyBackend(__func__, @@ -931,23 +922,7 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv cellStateOutInfo, outputInfo, desc, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights)) + paramsInfo)) { return false; } |