aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-03 18:20:40 +0100
committerNikhil Raj Arm <nikhil.raj@arm.com>2019-07-09 11:22:28 +0000
commitd01a83c8de77c44a938a618918d17385da3baa88 (patch)
treefca6f5422adfbdcce059049b36d32e0168edcef4 /src/backends/cl
parente6eaf661c5b84f4ca051daaf08281d9b8de3fcb9 (diff)
downloadarmnn-d01a83c8de77c44a938a618918d17385da3baa88.tar.gz
IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported
!android-nn-driver:1461 Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/cl')
-rw-r--r--src/backends/cl/ClLayerSupport.cpp42
-rw-r--r--src/backends/cl/ClLayerSupport.hpp24
-rw-r--r--src/backends/cl/workloads/ClLstmFloatWorkload.cpp72
-rw-r--r--src/backends/cl/workloads/ClLstmFloatWorkload.hpp17
4 files changed, 33 insertions, 122 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 497a6435df..6d9b197679 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -405,28 +405,8 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported,
- const TensorInfo* inputLayerNormWeights,
- const TensorInfo* forgetLayerNormWeights,
- const TensorInfo* cellLayerNormWeights,
- const TensorInfo* outputLayerNormWeights) const
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
{
FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
reasonIfUnsupported,
@@ -438,23 +418,7 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
cellStateOut,
output,
descriptor,
- inputToForgetWeights,
- inputToCellWeights,
- inputToOutputWeights,
- recurrentToForgetWeights,
- recurrentToCellWeights,
- recurrentToOutputWeights,
- forgetGateBias,
- cellBias,
- outputGateBias,
- inputToInputWeights,
- recurrentToInputWeights,
- cellToInputWeights,
- inputGateBias,
- projectionWeights,
- projectionBias,
- cellToForgetWeights,
- cellToOutputWeights);
+ paramsInfo);
}
bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 4a55997004..63a4daf864 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -114,28 +114,8 @@ public:
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
- const TensorInfo* inputLayerNormWeights = nullptr,
- const TensorInfo* forgetLayerNormWeights = nullptr,
- const TensorInfo* cellLayerNormWeights = nullptr,
- const TensorInfo* outputLayerNormWeights = nullptr) const override;
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& input1,
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
index f4d8974226..3dbbbc3784 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
@@ -224,22 +224,7 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights)
+ const LstmInputParamsInfo& paramsInfo)
{
arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
@@ -253,18 +238,21 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
// Basic parameters
- const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
- const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
- const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
+ const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+ const arm_compute::TensorInfo aclInputToCellWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+ const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToForgetWeights);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToCellWeights);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToOutputWeights);
- const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
- const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
- const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+ const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+ const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+ const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
arm_compute::TensorInfo aclInputToInputWeightsInfo;
arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -277,43 +265,37 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
if (!descriptor.m_CifgEnabled)
{
- armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
- aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
- armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
- aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
+ aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+ aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
- if (cellToInputWeights != nullptr)
+ if (paramsInfo.m_CellToInputWeights != nullptr)
{
- armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
- aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
+ aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
}
- armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
- aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
+ aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
- cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
+ paramsInfo.m_CellToInputWeights != nullptr ?
+ &aclCellToInputWeightsInfo: nullptr,
&aclInputGateBiasInfo);
}
if (descriptor.m_ProjectionEnabled)
{
- const armnn::TensorInfo& projectionWInfo = *projectionWeights;
- aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
+ aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
- if (projectionBias != nullptr)
+ if (paramsInfo.m_ProjectionBias != nullptr)
{
- const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
- aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
+ aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
}
lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
- projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
+ paramsInfo.m_ProjectionBias != nullptr ?
+ &aclProjectionBiasInfo: nullptr);
}
if (descriptor.m_PeepholeEnabled)
{
- const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
- aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
- const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
- aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
+ aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
+ aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
}
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
index 6a0c41fae3..9a3211a037 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
@@ -49,20 +49,5 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor &descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights);
+ const LstmInputParamsInfo& paramsInfo);
} //namespace armnn