diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 74 |
1 files changed, 27 insertions, 47 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 3502c381e8..1c23e1774b 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -388,20 +388,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, const TensorInfo& outputGateBias = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType); - // Optional parameters - const TensorInfo* inputToInputWeights = nullptr; - const TensorInfo* recurrentToInputWeights = nullptr; - const TensorInfo* cellToInputWeights = nullptr; - const TensorInfo* inputGateBias = nullptr; - const TensorInfo* projectionWeights = nullptr; - const TensorInfo* projectionBias = nullptr; - const TensorInfo* cellToForgetWeights = nullptr; - const TensorInfo* cellToOutputWeights = nullptr; - const TensorInfo* inputLayerNormWeights = nullptr; - const TensorInfo* forgetLayerNormWeights = nullptr; - const TensorInfo* cellLayerNormWeights = nullptr; - const TensorInfo* outputLayerNormWeights = nullptr; + LstmInputParamsInfo paramsInfo; + + paramsInfo.m_InputToForgetWeights = &inputToForgetWeights; + paramsInfo.m_InputToCellWeights = &inputToCellWeights; + paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; + paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; + paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + paramsInfo.m_ForgetGateBias = &forgetGateBias; + paramsInfo.m_CellBias = &cellBias; + paramsInfo.m_OutputGateBias = &outputGateBias; + + // Optional parameters TensorInfo optInputToInputWeights; TensorInfo optRecurrentToInputWeights; TensorInfo optCellToInputWeights; @@ -419,32 +419,32 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optInputToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType); - inputToInputWeights = &optInputToInputWeights; + paramsInfo.m_InputToInputWeights = &optInputToInputWeights; optRecurrentToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); - recurrentToInputWeights = &optRecurrentToInputWeights; + paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights; if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr) { optCellToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType); - cellToInputWeights = &optCellToInputWeights; + paramsInfo.m_CellToInputWeights = &optCellToInputWeights; } optInputGateBias = OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType); - inputGateBias = &optInputGateBias; + paramsInfo.m_InputGateBias = &optInputGateBias; } if(descriptor.m_ProjectionEnabled) { optProjectionWeights = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType); - projectionWeights = &optProjectionWeights; + paramsInfo.m_ProjectionWeights = &optProjectionWeights; if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr) { optProjectionBias = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType); - projectionBias = &optProjectionBias; + paramsInfo.m_ProjectionBias = &optProjectionBias; } } @@ -452,29 +452,29 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optCellToForgetWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType); - cellToForgetWeights = &optCellToForgetWeights; + paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights; optCellToOutputWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType); - cellToOutputWeights = &optCellToOutputWeights; + paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights; } if(descriptor.m_LayerNormEnabled) { optInputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); - inputLayerNormWeights = &optInputLayerNormWeights; + paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; optForgetLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType); - forgetLayerNormWeights = &optForgetLayerNormWeights; + paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights; optCellLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType); - cellLayerNormWeights = &optCellLayerNormWeights; + paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights; optOutputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType); - outputLayerNormWeights = &optOutputLayerNormWeights; + paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights; } result = layerSupportObject->IsLstmSupported( @@ -486,28 +486,8 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, cellStateOut, output, descriptor, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights, - reason, - inputLayerNormWeights, - forgetLayerNormWeights, - cellLayerNormWeights, - outputLayerNormWeights); + paramsInfo, + reason); break; } case LayerType::Maximum: |