aboutsummaryrefslogtreecommitdiff
path: root/1.0/HalPolicy.cpp
diff options
context:
space:
mode:
Diffstat (limited to '1.0/HalPolicy.cpp')
-rw-r--r--1.0/HalPolicy.cpp63
1 files changed, 19 insertions, 44 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index 13c93277..9673a74c 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -874,50 +874,41 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
// Basic parameters
- const armnn::TensorInfo& inputToForgetWeights = params.m_InputToForgetWeights->GetInfo();
- const armnn::TensorInfo& inputToCellWeights = params.m_InputToCellWeights->GetInfo();
- const armnn::TensorInfo& inputToOutputWeights = params.m_InputToOutputWeights->GetInfo();
- const armnn::TensorInfo& recurrentToForgetWeights = params.m_RecurrentToForgetWeights->GetInfo();
- const armnn::TensorInfo& recurrentToCellWeights = params.m_RecurrentToCellWeights->GetInfo();
- const armnn::TensorInfo& recurrentToOutputWeights = params.m_RecurrentToOutputWeights->GetInfo();
- const armnn::TensorInfo& forgetGateBias = params.m_ForgetGateBias->GetInfo();
- const armnn::TensorInfo& cellBias = params.m_CellBias->GetInfo();
- const armnn::TensorInfo& outputGateBias = params.m_OutputGateBias->GetInfo();
-
- //Optional parameters
- const armnn::TensorInfo* inputToInputWeights = nullptr;
- const armnn::TensorInfo* recurrentToInputWeights = nullptr;
- const armnn::TensorInfo* cellToInputWeights = nullptr;
- const armnn::TensorInfo* inputGateBias = nullptr;
- const armnn::TensorInfo* projectionWeights = nullptr;
- const armnn::TensorInfo* projectionBias = nullptr;
- const armnn::TensorInfo* cellToForgetWeights = nullptr;
- const armnn::TensorInfo* cellToOutputWeights = nullptr;
+ armnn::LstmInputParamsInfo paramsInfo;
+ paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
+ paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo());
+ paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo());
+ paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
+ paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo());
+ paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
+ paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo());
+ paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo());
+ paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo());
if(!desc.m_CifgEnabled)
{
- inputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
- recurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
+ paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
+ paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
if (params.m_CellToInputWeights != nullptr)
{
- cellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
+ paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
}
- inputGateBias = &(params.m_InputGateBias->GetInfo());
+ paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
}
if(desc.m_ProjectionEnabled)
{
- projectionWeights = &(params.m_ProjectionWeights->GetInfo());
+ paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
if (params.m_ProjectionBias != nullptr)
{
- projectionBias = &(params.m_ProjectionBias->GetInfo());
+ paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
}
}
if(desc.m_PeepholeEnabled)
{
- cellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
- cellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
+ paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
+ paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
}
if (!IsLayerSupportedForAnyBackend(__func__,
@@ -931,23 +922,7 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
cellStateOutInfo,
outputInfo,
desc,
- inputToForgetWeights,
- inputToCellWeights,
- inputToOutputWeights,
- recurrentToForgetWeights,
- recurrentToCellWeights,
- recurrentToOutputWeights,
- forgetGateBias,
- cellBias,
- outputGateBias,
- inputToInputWeights,
- recurrentToInputWeights,
- cellToInputWeights,
- inputGateBias,
- projectionWeights,
- projectionBias,
- cellToForgetWeights,
- cellToOutputWeights))
+ paramsInfo))
{
return false;
}