aboutsummaryrefslogtreecommitdiff
path: root/1.0
diff options
context:
space:
mode:
authorFerran Balaguer <ferran.balaguer@arm.com>2019-07-02 17:34:46 +0100
committerFerran Balaguer <ferran.balaguer@arm.com>2019-07-09 12:48:54 +0100
commit177fa0ba936eaf9de96f04bb91aa51d7656dd655 (patch)
treec2232336657d93a13f0359ed63a9f6555d4519d2 /1.0
parent44381518586476ce7aef78b00bc6a905ddf5730a (diff)
downloadandroid-nn-driver-177fa0ba936eaf9de96f04bb91aa51d7656dd655.tar.gz
IVGCVSW-3396 Support joined lstm parameters
!armnn:1470 Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com> Change-Id: I67a393c1556f0b3022436e41f82f2bf1ab3a1d40
Diffstat (limited to '1.0')
-rw-r--r--1.0/HalPolicy.cpp63
1 files changed, 19 insertions, 44 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index 13c93277..9673a74c 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -874,50 +874,41 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
// Basic parameters
- const armnn::TensorInfo& inputToForgetWeights = params.m_InputToForgetWeights->GetInfo();
- const armnn::TensorInfo& inputToCellWeights = params.m_InputToCellWeights->GetInfo();
- const armnn::TensorInfo& inputToOutputWeights = params.m_InputToOutputWeights->GetInfo();
- const armnn::TensorInfo& recurrentToForgetWeights = params.m_RecurrentToForgetWeights->GetInfo();
- const armnn::TensorInfo& recurrentToCellWeights = params.m_RecurrentToCellWeights->GetInfo();
- const armnn::TensorInfo& recurrentToOutputWeights = params.m_RecurrentToOutputWeights->GetInfo();
- const armnn::TensorInfo& forgetGateBias = params.m_ForgetGateBias->GetInfo();
- const armnn::TensorInfo& cellBias = params.m_CellBias->GetInfo();
- const armnn::TensorInfo& outputGateBias = params.m_OutputGateBias->GetInfo();
-
- //Optional parameters
- const armnn::TensorInfo* inputToInputWeights = nullptr;
- const armnn::TensorInfo* recurrentToInputWeights = nullptr;
- const armnn::TensorInfo* cellToInputWeights = nullptr;
- const armnn::TensorInfo* inputGateBias = nullptr;
- const armnn::TensorInfo* projectionWeights = nullptr;
- const armnn::TensorInfo* projectionBias = nullptr;
- const armnn::TensorInfo* cellToForgetWeights = nullptr;
- const armnn::TensorInfo* cellToOutputWeights = nullptr;
+ armnn::LstmInputParamsInfo paramsInfo;
+ paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
+ paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo());
+ paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo());
+ paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
+ paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo());
+ paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
+ paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo());
+ paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo());
+ paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo());
if(!desc.m_CifgEnabled)
{
- inputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
- recurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
+ paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
+ paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
if (params.m_CellToInputWeights != nullptr)
{
- cellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
+ paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
}
- inputGateBias = &(params.m_InputGateBias->GetInfo());
+ paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
}
if(desc.m_ProjectionEnabled)
{
- projectionWeights = &(params.m_ProjectionWeights->GetInfo());
+ paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
if (params.m_ProjectionBias != nullptr)
{
- projectionBias = &(params.m_ProjectionBias->GetInfo());
+ paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
}
}
if(desc.m_PeepholeEnabled)
{
- cellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
- cellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
+ paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
+ paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
}
if (!IsLayerSupportedForAnyBackend(__func__,
@@ -931,23 +922,7 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
cellStateOutInfo,
outputInfo,
desc,
- inputToForgetWeights,
- inputToCellWeights,
- inputToOutputWeights,
- recurrentToForgetWeights,
- recurrentToCellWeights,
- recurrentToOutputWeights,
- forgetGateBias,
- cellBias,
- outputGateBias,
- inputToInputWeights,
- recurrentToInputWeights,
- cellToInputWeights,
- inputGateBias,
- projectionWeights,
- projectionBias,
- cellToForgetWeights,
- cellToOutputWeights))
+ paramsInfo))
{
return false;
}