diff options
author | Jan Eilers <jan.eilers@arm.com> | 2019-07-03 18:20:40 +0100 |
---|---|---|
committer | Nikhil Raj Arm <nikhil.raj@arm.com> | 2019-07-09 11:22:28 +0000 |
commit | d01a83c8de77c44a938a618918d17385da3baa88 (patch) | |
tree | fca6f5422adfbdcce059049b36d32e0168edcef4 /src/backends/backendsCommon | |
parent | e6eaf661c5b84f4ca051daaf08281d9b8de3fcb9 (diff) | |
download | armnn-d01a83c8de77c44a938a618918d17385da3baa88.tar.gz |
IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported
!android-nn-driver:1461
Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.cpp | 24 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.hpp | 24 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 74 |
3 files changed, 31 insertions, 91 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index 4488e25c9c..ea22fac9ce 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -226,28 +226,8 @@ bool LayerSupportBase::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 03a928a706..36b8e77c38 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -140,28 +140,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const override; + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; bool IsMaximumSupported(const TensorInfo& input0, const TensorInfo& input1, diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 3502c381e8..1c23e1774b 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -388,20 +388,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, const TensorInfo& outputGateBias = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType); - // Optional parameters - const TensorInfo* inputToInputWeights = nullptr; - const TensorInfo* recurrentToInputWeights = nullptr; - const TensorInfo* cellToInputWeights = nullptr; - const TensorInfo* inputGateBias = nullptr; - const TensorInfo* projectionWeights = nullptr; - const TensorInfo* projectionBias = nullptr; - const TensorInfo* cellToForgetWeights = nullptr; - const TensorInfo* cellToOutputWeights = nullptr; - const TensorInfo* inputLayerNormWeights = nullptr; - const TensorInfo* forgetLayerNormWeights = nullptr; - const TensorInfo* cellLayerNormWeights = nullptr; - const TensorInfo* outputLayerNormWeights = nullptr; + LstmInputParamsInfo paramsInfo; + + paramsInfo.m_InputToForgetWeights = &inputToForgetWeights; + paramsInfo.m_InputToCellWeights = &inputToCellWeights; + paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; + paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; + paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + paramsInfo.m_ForgetGateBias = &forgetGateBias; + paramsInfo.m_CellBias = &cellBias; + paramsInfo.m_OutputGateBias = &outputGateBias; + + // Optional parameters TensorInfo optInputToInputWeights; TensorInfo optRecurrentToInputWeights; TensorInfo optCellToInputWeights; @@ -419,32 +419,32 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optInputToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType); - inputToInputWeights = &optInputToInputWeights; + paramsInfo.m_InputToInputWeights = &optInputToInputWeights; optRecurrentToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); - recurrentToInputWeights = &optRecurrentToInputWeights; + paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights; if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr) { optCellToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType); - cellToInputWeights = &optCellToInputWeights; + paramsInfo.m_CellToInputWeights = &optCellToInputWeights; } optInputGateBias = OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType); - inputGateBias = &optInputGateBias; + paramsInfo.m_InputGateBias = &optInputGateBias; } if(descriptor.m_ProjectionEnabled) { optProjectionWeights = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType); - projectionWeights = &optProjectionWeights; + paramsInfo.m_ProjectionWeights = &optProjectionWeights; if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr) { optProjectionBias = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType); - projectionBias = &optProjectionBias; + paramsInfo.m_ProjectionBias = &optProjectionBias; } } @@ -452,29 +452,29 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optCellToForgetWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType); - cellToForgetWeights = &optCellToForgetWeights; + paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights; optCellToOutputWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType); - cellToOutputWeights = &optCellToOutputWeights; + paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights; } if(descriptor.m_LayerNormEnabled) { optInputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); - inputLayerNormWeights = &optInputLayerNormWeights; + paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; optForgetLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType); - forgetLayerNormWeights = &optForgetLayerNormWeights; + paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights; optCellLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType); - cellLayerNormWeights = &optCellLayerNormWeights; + paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights; optOutputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType); - outputLayerNormWeights = &optOutputLayerNormWeights; + paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights; } result = layerSupportObject->IsLstmSupported( @@ -486,28 +486,8 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, cellStateOut, output, descriptor, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights, - reason, - inputLayerNormWeights, - forgetLayerNormWeights, - cellLayerNormWeights, - outputLayerNormWeights); + paramsInfo, + reason); break; } case LayerType::Maximum: |