diff options
author | Jan Eilers <jan.eilers@arm.com> | 2019-07-03 18:20:40 +0100 |
---|---|---|
committer | Nikhil Raj Arm <nikhil.raj@arm.com> | 2019-07-09 11:22:28 +0000 |
commit | d01a83c8de77c44a938a618918d17385da3baa88 (patch) | |
tree | fca6f5422adfbdcce059049b36d32e0168edcef4 /src/backends/backendsCommon/WorkloadFactory.cpp | |
parent | e6eaf661c5b84f4ca051daaf08281d9b8de3fcb9 (diff) | |
download | armnn-d01a83c8de77c44a938a618918d17385da3baa88.tar.gz |
IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported
!android-nn-driver:1461
Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 74 |
1 files changed, 27 insertions, 47 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 3502c381e8..1c23e1774b 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -388,20 +388,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, const TensorInfo& outputGateBias = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType); - // Optional parameters - const TensorInfo* inputToInputWeights = nullptr; - const TensorInfo* recurrentToInputWeights = nullptr; - const TensorInfo* cellToInputWeights = nullptr; - const TensorInfo* inputGateBias = nullptr; - const TensorInfo* projectionWeights = nullptr; - const TensorInfo* projectionBias = nullptr; - const TensorInfo* cellToForgetWeights = nullptr; - const TensorInfo* cellToOutputWeights = nullptr; - const TensorInfo* inputLayerNormWeights = nullptr; - const TensorInfo* forgetLayerNormWeights = nullptr; - const TensorInfo* cellLayerNormWeights = nullptr; - const TensorInfo* outputLayerNormWeights = nullptr; + LstmInputParamsInfo paramsInfo; + + paramsInfo.m_InputToForgetWeights = &inputToForgetWeights; + paramsInfo.m_InputToCellWeights = &inputToCellWeights; + paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; + paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; + paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + paramsInfo.m_ForgetGateBias = &forgetGateBias; + paramsInfo.m_CellBias = &cellBias; + paramsInfo.m_OutputGateBias = &outputGateBias; + + // Optional parameters TensorInfo optInputToInputWeights; TensorInfo optRecurrentToInputWeights; TensorInfo optCellToInputWeights; @@ -419,32 +419,32 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optInputToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType); - inputToInputWeights = &optInputToInputWeights; + paramsInfo.m_InputToInputWeights = &optInputToInputWeights; optRecurrentToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); - recurrentToInputWeights = &optRecurrentToInputWeights; + paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights; if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr) { optCellToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType); - cellToInputWeights = &optCellToInputWeights; + paramsInfo.m_CellToInputWeights = &optCellToInputWeights; } optInputGateBias = OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType); - inputGateBias = &optInputGateBias; + paramsInfo.m_InputGateBias = &optInputGateBias; } if(descriptor.m_ProjectionEnabled) { optProjectionWeights = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType); - projectionWeights = &optProjectionWeights; + paramsInfo.m_ProjectionWeights = &optProjectionWeights; if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr) { optProjectionBias = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType); - projectionBias = &optProjectionBias; + paramsInfo.m_ProjectionBias = &optProjectionBias; } } @@ -452,29 +452,29 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optCellToForgetWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType); - cellToForgetWeights = &optCellToForgetWeights; + paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights; optCellToOutputWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType); - cellToOutputWeights = &optCellToOutputWeights; + paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights; } if(descriptor.m_LayerNormEnabled) { optInputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); - inputLayerNormWeights = &optInputLayerNormWeights; + paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; optForgetLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType); - forgetLayerNormWeights = &optForgetLayerNormWeights; + paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights; optCellLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType); - cellLayerNormWeights = &optCellLayerNormWeights; + paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights; optOutputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType); - outputLayerNormWeights = &optOutputLayerNormWeights; + paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights; } result = layerSupportObject->IsLstmSupported( @@ -486,28 +486,8 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, cellStateOut, output, descriptor, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights, - reason, - inputLayerNormWeights, - forgetLayerNormWeights, - cellLayerNormWeights, - outputLayerNormWeights); + paramsInfo, + reason); break; } case LayerType::Maximum: |