diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnn/LayerSupport.cpp | 18 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.cpp | 24 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.hpp | 24 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 74 | ||||
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 42 | ||||
-rw-r--r-- | src/backends/cl/ClLayerSupport.hpp | 24 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClLstmFloatWorkload.cpp | 72 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClLstmFloatWorkload.hpp | 17 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 123 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 24 |
10 files changed, 142 insertions, 300 deletions
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp index b2ca85c04e..a2908aae33 100644 --- a/src/armnn/LayerSupport.cpp +++ b/src/armnn/LayerSupport.cpp @@ -333,27 +333,13 @@ bool IsLstmSupported(const BackendId& backend, const TensorInfo& input, const Te const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, char* reasonIfUnsupported, + const LstmInputParamsInfo& paramsInfo, char* reasonIfUnsupported, size_t reasonIfUnsupportedMaxLength) { FORWARD_LAYER_SUPPORT_FUNC(backend, IsLstmSupported, input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut, - output, descriptor, inputToForgetWeights, inputToCellWeights, - inputToOutputWeights, recurrentToForgetWeights, - recurrentToCellWeights, recurrentToOutputWeights, - forgetGateBias, cellBias, outputGateBias, - inputToInputWeights, recurrentToInputWeights, - cellToInputWeights, inputGateBias, projectionWeights, - projectionBias, cellToForgetWeights, cellToOutputWeights); + output, descriptor, paramsInfo); } bool IsMaximumSupported(const BackendId& backend, diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index 4488e25c9c..ea22fac9ce 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -226,28 +226,8 @@ bool LayerSupportBase::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 03a928a706..36b8e77c38 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -140,28 +140,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const override; + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; bool IsMaximumSupported(const TensorInfo& input0, const TensorInfo& input1, diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 3502c381e8..1c23e1774b 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -388,20 +388,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, const TensorInfo& outputGateBias = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType); - // Optional parameters - const TensorInfo* inputToInputWeights = nullptr; - const TensorInfo* recurrentToInputWeights = nullptr; - const TensorInfo* cellToInputWeights = nullptr; - const TensorInfo* inputGateBias = nullptr; - const TensorInfo* projectionWeights = nullptr; - const TensorInfo* projectionBias = nullptr; - const TensorInfo* cellToForgetWeights = nullptr; - const TensorInfo* cellToOutputWeights = nullptr; - const TensorInfo* inputLayerNormWeights = nullptr; - const TensorInfo* forgetLayerNormWeights = nullptr; - const TensorInfo* cellLayerNormWeights = nullptr; - const TensorInfo* outputLayerNormWeights = nullptr; + LstmInputParamsInfo paramsInfo; + + paramsInfo.m_InputToForgetWeights = &inputToForgetWeights; + paramsInfo.m_InputToCellWeights = &inputToCellWeights; + paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; + paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; + paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + paramsInfo.m_ForgetGateBias = &forgetGateBias; + paramsInfo.m_CellBias = &cellBias; + paramsInfo.m_OutputGateBias = &outputGateBias; + + // Optional parameters TensorInfo optInputToInputWeights; TensorInfo optRecurrentToInputWeights; TensorInfo optCellToInputWeights; @@ -419,32 +419,32 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optInputToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType); - inputToInputWeights = &optInputToInputWeights; + paramsInfo.m_InputToInputWeights = &optInputToInputWeights; optRecurrentToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); - recurrentToInputWeights = &optRecurrentToInputWeights; + paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights; if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr) { optCellToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType); - cellToInputWeights = &optCellToInputWeights; + paramsInfo.m_CellToInputWeights = &optCellToInputWeights; } optInputGateBias = OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType); - inputGateBias = &optInputGateBias; + paramsInfo.m_InputGateBias = &optInputGateBias; } if(descriptor.m_ProjectionEnabled) { optProjectionWeights = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType); - projectionWeights = &optProjectionWeights; + paramsInfo.m_ProjectionWeights = &optProjectionWeights; if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr) { optProjectionBias = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType); - projectionBias = &optProjectionBias; + paramsInfo.m_ProjectionBias = &optProjectionBias; } } @@ -452,29 +452,29 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optCellToForgetWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType); - cellToForgetWeights = &optCellToForgetWeights; + paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights; optCellToOutputWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType); - cellToOutputWeights = &optCellToOutputWeights; + paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights; } if(descriptor.m_LayerNormEnabled) { optInputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); - inputLayerNormWeights = &optInputLayerNormWeights; + paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; optForgetLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType); - forgetLayerNormWeights = &optForgetLayerNormWeights; + paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights; optCellLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType); - cellLayerNormWeights = &optCellLayerNormWeights; + paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights; optOutputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType); - outputLayerNormWeights = &optOutputLayerNormWeights; + paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights; } result = layerSupportObject->IsLstmSupported( @@ -486,28 +486,8 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, cellStateOut, output, descriptor, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights, - reason, - inputLayerNormWeights, - forgetLayerNormWeights, - cellLayerNormWeights, - outputLayerNormWeights); + paramsInfo, + reason); break; } case LayerType::Maximum: diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 497a6435df..6d9b197679 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -405,28 +405,8 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const { FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate, reasonIfUnsupported, @@ -438,23 +418,7 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input, cellStateOut, output, descriptor, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights); + paramsInfo); } bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0, diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp index 4a55997004..63a4daf864 100644 --- a/src/backends/cl/ClLayerSupport.hpp +++ b/src/backends/cl/ClLayerSupport.hpp @@ -114,28 +114,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const override; + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; bool IsMaximumSupported(const TensorInfo& input0, const TensorInfo& input1, diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp index f4d8974226..3dbbbc3784 100644 --- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp +++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp @@ -224,22 +224,7 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights) + const LstmInputParamsInfo& paramsInfo) { arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info; @@ -253,18 +238,21 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); // Basic parameters - const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights); - const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights); - const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights); + const arm_compute::TensorInfo aclInputToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights()); + const arm_compute::TensorInfo aclInputToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights()); + const arm_compute::TensorInfo aclInputToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights()); const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo - = BuildArmComputeTensorInfo(recurrentToForgetWeights); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights()); const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo - = BuildArmComputeTensorInfo(recurrentToCellWeights); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights()); const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo - = BuildArmComputeTensorInfo(recurrentToOutputWeights); - const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias); - const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias); - const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights()); + const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias()); + const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellBias()); + const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias()); arm_compute::TensorInfo aclInputToInputWeightsInfo; arm_compute::TensorInfo aclRecurrentToInputWeightsInfo; @@ -277,43 +265,37 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T if (!descriptor.m_CifgEnabled) { - armnn::TensorInfo inputToInputWInfo = *inputToInputWeights; - aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo); - armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights; - aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo); + aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights()); + aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights()); - if (cellToInputWeights != nullptr) + if (paramsInfo.m_CellToInputWeights != nullptr) { - armnn::TensorInfo cellToInputWInfo = *cellToInputWeights; - aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo); + aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights()); } - armnn::TensorInfo inputGateBiasInfo = *inputGateBias; - aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo); + aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias()); lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo, - cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr, + paramsInfo.m_CellToInputWeights != nullptr ? + &aclCellToInputWeightsInfo: nullptr, &aclInputGateBiasInfo); } if (descriptor.m_ProjectionEnabled) { - const armnn::TensorInfo& projectionWInfo = *projectionWeights; - aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo); + aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights()); - if (projectionBias != nullptr) + if (paramsInfo.m_ProjectionBias != nullptr) { - const armnn::TensorInfo& projectionBiasInfo = *projectionBias; - aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo); + aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias()); } lstm_params_info.set_projection_params(&aclProjectionWeightsInfo, - projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr); + paramsInfo.m_ProjectionBias != nullptr ? + &aclProjectionBiasInfo: nullptr); } if (descriptor.m_PeepholeEnabled) { - const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights; - aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo); - const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights; - aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo); + aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights()); + aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights()); lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo); } diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp index 6a0c41fae3..9a3211a037 100644 --- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp +++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp @@ -49,20 +49,5 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor &descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights); + const LstmInputParamsInfo& paramsInfo); } //namespace armnn diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index ac7f310c1d..59c14c4490 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -924,51 +924,11 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const { ignore_unused(descriptor); - ignore_unused(inputToForgetWeights); - ignore_unused(inputToCellWeights); - ignore_unused(inputToOutputWeights); - ignore_unused(recurrentToForgetWeights); - ignore_unused(recurrentToCellWeights); - ignore_unused(recurrentToOutputWeights); - ignore_unused(forgetGateBias); - ignore_unused(cellBias); - ignore_unused(outputGateBias); - ignore_unused(inputToInputWeights); - ignore_unused(recurrentToInputWeights); - ignore_unused(cellToInputWeights); - ignore_unused(inputGateBias); - ignore_unused(projectionWeights); - ignore_unused(projectionBias); - ignore_unused(cellToForgetWeights); - ignore_unused(cellToOutputWeights); - ignore_unused(inputLayerNormWeights); - ignore_unused(forgetLayerNormWeights); - ignore_unused(cellLayerNormWeights); - ignore_unused(outputLayerNormWeights); + ignore_unused(paramsInfo); bool supported = true; @@ -977,26 +937,91 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, DataType::QuantisedSymm16 }; + // check inputs and outputs supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, "Reference Lstm: input is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, "Reference Lstm: input and outputStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, "Reference Lstm: input and cellStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported, "Reference Lstm: input and scratchBuffer types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported, "Reference Lstm: input and outputStateOut types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported, "Reference Lstm: input and cellStateOut types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, "Reference Lstm: input and output types are mismatched"); + // check layer parameters + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToCellWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToOutputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToCellWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToOutputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported, + "Reference Lstm: input and ForgetGateBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported, + "Reference Lstm: input and CellBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported, + "Reference Lstm: input and OutputGateBias types are mismatched"); + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToInputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()), + reasonIfUnsupported, + "Reference Lstm: input and RecurrentToInputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported, + "Reference Lstm: input and InputGateBias types are mismatched"); + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()), + reasonIfUnsupported, + "Reference Lstm: input and CellToInputWeights types are mismatched"); + } + } + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and CellToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and CellToOutputWeights types are mismatched"); + } + if (descriptor.m_ProjectionEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported, + "Reference Lstm: input and mProjectionWeights types are mismatched"); + if (paramsInfo.m_ProjectionBias != nullptr) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported, + "Reference Lstm: input and ProjectionBias types are mismatched"); + } + } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and InputLayerNormWeights types are mismatched"); + } + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and ForgetLayerNormWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and CellLayerNormWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and OutputLayerNormWeights types are mismatched"); + } return supported; } diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index ead4d1ce4a..c0bf18824e 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -138,28 +138,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const override; + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; bool IsMaximumSupported(const TensorInfo& input0, const TensorInfo& input1, |