aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-03 18:20:40 +0100
committerNikhil Raj Arm <nikhil.raj@arm.com>2019-07-09 11:22:28 +0000
commitd01a83c8de77c44a938a618918d17385da3baa88 (patch)
treefca6f5422adfbdcce059049b36d32e0168edcef4 /src
parente6eaf661c5b84f4ca051daaf08281d9b8de3fcb9 (diff)
downloadarmnn-d01a83c8de77c44a938a618918d17385da3baa88.tar.gz
IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported
!android-nn-driver:1461 Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/armnn/LayerSupport.cpp18
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp24
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp24
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp74
-rw-r--r--src/backends/cl/ClLayerSupport.cpp42
-rw-r--r--src/backends/cl/ClLayerSupport.hpp24
-rw-r--r--src/backends/cl/workloads/ClLstmFloatWorkload.cpp72
-rw-r--r--src/backends/cl/workloads/ClLstmFloatWorkload.hpp17
-rw-r--r--src/backends/reference/RefLayerSupport.cpp123
-rw-r--r--src/backends/reference/RefLayerSupport.hpp24
10 files changed, 142 insertions, 300 deletions
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index b2ca85c04e..a2908aae33 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -333,27 +333,13 @@ bool IsLstmSupported(const BackendId& backend, const TensorInfo& input, const Te
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights, char* reasonIfUnsupported,
+ const LstmInputParamsInfo& paramsInfo, char* reasonIfUnsupported,
size_t reasonIfUnsupportedMaxLength)
{
FORWARD_LAYER_SUPPORT_FUNC(backend, IsLstmSupported, input, outputStateIn, cellStateIn,
scratchBuffer, outputStateOut, cellStateOut,
- output, descriptor, inputToForgetWeights, inputToCellWeights,
- inputToOutputWeights, recurrentToForgetWeights,
- recurrentToCellWeights, recurrentToOutputWeights,
- forgetGateBias, cellBias, outputGateBias,
- inputToInputWeights, recurrentToInputWeights,
- cellToInputWeights, inputGateBias, projectionWeights,
- projectionBias, cellToForgetWeights, cellToOutputWeights);
+ output, descriptor, paramsInfo);
}
bool IsMaximumSupported(const BackendId& backend,
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 4488e25c9c..ea22fac9ce 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -226,28 +226,8 @@ bool LayerSupportBase::IsLstmSupported(const TensorInfo& input,
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported,
- const TensorInfo* inputLayerNormWeights,
- const TensorInfo* forgetLayerNormWeights,
- const TensorInfo* cellLayerNormWeights,
- const TensorInfo* outputLayerNormWeights) const
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 03a928a706..36b8e77c38 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -140,28 +140,8 @@ public:
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
- const TensorInfo* inputLayerNormWeights = nullptr,
- const TensorInfo* forgetLayerNormWeights = nullptr,
- const TensorInfo* cellLayerNormWeights = nullptr,
- const TensorInfo* outputLayerNormWeights = nullptr) const override;
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& input1,
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 3502c381e8..1c23e1774b 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -388,20 +388,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
const TensorInfo& outputGateBias
= OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
- // Optional parameters
- const TensorInfo* inputToInputWeights = nullptr;
- const TensorInfo* recurrentToInputWeights = nullptr;
- const TensorInfo* cellToInputWeights = nullptr;
- const TensorInfo* inputGateBias = nullptr;
- const TensorInfo* projectionWeights = nullptr;
- const TensorInfo* projectionBias = nullptr;
- const TensorInfo* cellToForgetWeights = nullptr;
- const TensorInfo* cellToOutputWeights = nullptr;
- const TensorInfo* inputLayerNormWeights = nullptr;
- const TensorInfo* forgetLayerNormWeights = nullptr;
- const TensorInfo* cellLayerNormWeights = nullptr;
- const TensorInfo* outputLayerNormWeights = nullptr;
+ LstmInputParamsInfo paramsInfo;
+
+ paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
+ paramsInfo.m_InputToCellWeights = &inputToCellWeights;
+ paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
+ paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ paramsInfo.m_ForgetGateBias = &forgetGateBias;
+ paramsInfo.m_CellBias = &cellBias;
+ paramsInfo.m_OutputGateBias = &outputGateBias;
+
+ // Optional parameters
TensorInfo optInputToInputWeights;
TensorInfo optRecurrentToInputWeights;
TensorInfo optCellToInputWeights;
@@ -419,32 +419,32 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
{
optInputToInputWeights =
OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
- inputToInputWeights = &optInputToInputWeights;
+ paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
optRecurrentToInputWeights =
OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
- recurrentToInputWeights = &optRecurrentToInputWeights;
+ paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
{
optCellToInputWeights =
OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
- cellToInputWeights = &optCellToInputWeights;
+ paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
}
optInputGateBias =
OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
- inputGateBias = &optInputGateBias;
+ paramsInfo.m_InputGateBias = &optInputGateBias;
}
if(descriptor.m_ProjectionEnabled)
{
optProjectionWeights =
OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
- projectionWeights = &optProjectionWeights;
+ paramsInfo.m_ProjectionWeights = &optProjectionWeights;
if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
{
optProjectionBias =
OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
- projectionBias = &optProjectionBias;
+ paramsInfo.m_ProjectionBias = &optProjectionBias;
}
}
@@ -452,29 +452,29 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
{
optCellToForgetWeights =
OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
- cellToForgetWeights = &optCellToForgetWeights;
+ paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
optCellToOutputWeights =
OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
- cellToOutputWeights = &optCellToOutputWeights;
+ paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
}
if(descriptor.m_LayerNormEnabled)
{
optInputLayerNormWeights = OverrideDataType(
cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
- inputLayerNormWeights = &optInputLayerNormWeights;
+ paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
optForgetLayerNormWeights = OverrideDataType(
cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
- forgetLayerNormWeights = &optForgetLayerNormWeights;
+ paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
optCellLayerNormWeights = OverrideDataType(
cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
- cellLayerNormWeights = &optCellLayerNormWeights;
+ paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
optOutputLayerNormWeights = OverrideDataType(
cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
- outputLayerNormWeights = &optOutputLayerNormWeights;
+ paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
}
result = layerSupportObject->IsLstmSupported(
@@ -486,28 +486,8 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
cellStateOut,
output,
descriptor,
- inputToForgetWeights,
- inputToCellWeights,
- inputToOutputWeights,
- recurrentToForgetWeights,
- recurrentToCellWeights,
- recurrentToOutputWeights,
- forgetGateBias,
- cellBias,
- outputGateBias,
- inputToInputWeights,
- recurrentToInputWeights,
- cellToInputWeights,
- inputGateBias,
- projectionWeights,
- projectionBias,
- cellToForgetWeights,
- cellToOutputWeights,
- reason,
- inputLayerNormWeights,
- forgetLayerNormWeights,
- cellLayerNormWeights,
- outputLayerNormWeights);
+ paramsInfo,
+ reason);
break;
}
case LayerType::Maximum:
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 497a6435df..6d9b197679 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -405,28 +405,8 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported,
- const TensorInfo* inputLayerNormWeights,
- const TensorInfo* forgetLayerNormWeights,
- const TensorInfo* cellLayerNormWeights,
- const TensorInfo* outputLayerNormWeights) const
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
{
FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
reasonIfUnsupported,
@@ -438,23 +418,7 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
cellStateOut,
output,
descriptor,
- inputToForgetWeights,
- inputToCellWeights,
- inputToOutputWeights,
- recurrentToForgetWeights,
- recurrentToCellWeights,
- recurrentToOutputWeights,
- forgetGateBias,
- cellBias,
- outputGateBias,
- inputToInputWeights,
- recurrentToInputWeights,
- cellToInputWeights,
- inputGateBias,
- projectionWeights,
- projectionBias,
- cellToForgetWeights,
- cellToOutputWeights);
+ paramsInfo);
}
bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 4a55997004..63a4daf864 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -114,28 +114,8 @@ public:
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
- const TensorInfo* inputLayerNormWeights = nullptr,
- const TensorInfo* forgetLayerNormWeights = nullptr,
- const TensorInfo* cellLayerNormWeights = nullptr,
- const TensorInfo* outputLayerNormWeights = nullptr) const override;
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& input1,
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
index f4d8974226..3dbbbc3784 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
@@ -224,22 +224,7 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights)
+ const LstmInputParamsInfo& paramsInfo)
{
arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
@@ -253,18 +238,21 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
// Basic parameters
- const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
- const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
- const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
+ const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+ const arm_compute::TensorInfo aclInputToCellWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+ const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToForgetWeights);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToCellWeights);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToOutputWeights);
- const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
- const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
- const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+ const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+ const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+ const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
arm_compute::TensorInfo aclInputToInputWeightsInfo;
arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -277,43 +265,37 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
if (!descriptor.m_CifgEnabled)
{
- armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
- aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
- armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
- aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
+ aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+ aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
- if (cellToInputWeights != nullptr)
+ if (paramsInfo.m_CellToInputWeights != nullptr)
{
- armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
- aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
+ aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
}
- armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
- aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
+ aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
- cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
+ paramsInfo.m_CellToInputWeights != nullptr ?
+ &aclCellToInputWeightsInfo: nullptr,
&aclInputGateBiasInfo);
}
if (descriptor.m_ProjectionEnabled)
{
- const armnn::TensorInfo& projectionWInfo = *projectionWeights;
- aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
+ aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
- if (projectionBias != nullptr)
+ if (paramsInfo.m_ProjectionBias != nullptr)
{
- const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
- aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
+            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias());
}
lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
- projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
+ paramsInfo.m_ProjectionBias != nullptr ?
+ &aclProjectionBiasInfo: nullptr);
}
if (descriptor.m_PeepholeEnabled)
{
- const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
- aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
- const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
- aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
+ aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
+ aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
}
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
index 6a0c41fae3..9a3211a037 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
@@ -49,20 +49,5 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor &descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights);
+ const LstmInputParamsInfo& paramsInfo);
} //namespace armnn
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index ac7f310c1d..59c14c4490 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -924,51 +924,11 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported,
- const TensorInfo* inputLayerNormWeights,
- const TensorInfo* forgetLayerNormWeights,
- const TensorInfo* cellLayerNormWeights,
- const TensorInfo* outputLayerNormWeights) const
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(descriptor);
- ignore_unused(inputToForgetWeights);
- ignore_unused(inputToCellWeights);
- ignore_unused(inputToOutputWeights);
- ignore_unused(recurrentToForgetWeights);
- ignore_unused(recurrentToCellWeights);
- ignore_unused(recurrentToOutputWeights);
- ignore_unused(forgetGateBias);
- ignore_unused(cellBias);
- ignore_unused(outputGateBias);
- ignore_unused(inputToInputWeights);
- ignore_unused(recurrentToInputWeights);
- ignore_unused(cellToInputWeights);
- ignore_unused(inputGateBias);
- ignore_unused(projectionWeights);
- ignore_unused(projectionBias);
- ignore_unused(cellToForgetWeights);
- ignore_unused(cellToOutputWeights);
- ignore_unused(inputLayerNormWeights);
- ignore_unused(forgetLayerNormWeights);
- ignore_unused(cellLayerNormWeights);
- ignore_unused(outputLayerNormWeights);
+ ignore_unused(paramsInfo);
bool supported = true;
@@ -977,26 +937,91 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
DataType::QuantisedSymm16
};
+ // check inputs and outputs
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
"Reference Lstm: input is not a supported type.");
-
supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
"Reference Lstm: input and outputStateIn types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
"Reference Lstm: input and cellStateIn types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
"Reference Lstm: input and scratchBuffer types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
"Reference Lstm: input and outputStateOut types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
"Reference Lstm: input and cellStateOut types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
"Reference Lstm: input and output types are mismatched");
+ // check layer parameters
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and InputToForgetWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and InputToCellWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and InputToOutputWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
+ "Reference Lstm: input and ForgetGateBias types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
+ "Reference Lstm: input and CellBias types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
+ "Reference Lstm: input and OutputGateBias types are mismatched");
+ if (!descriptor.m_CifgEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and InputToInputWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
+ "Reference Lstm: input and InputGateBias types are mismatched");
+ if (descriptor.m_PeepholeEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and CellToInputWeights types are mismatched");
+ }
+ }
+ if (descriptor.m_PeepholeEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and CellToForgetWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and CellToOutputWeights types are mismatched");
+ }
+ if (descriptor.m_ProjectionEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and ProjectionWeights types are mismatched");
+ if (paramsInfo.m_ProjectionBias != nullptr)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
+ "Reference Lstm: input and ProjectionBias types are mismatched");
+ }
+ }
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and InputLayerNormWeights types are mismatched");
+ }
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and CellLayerNormWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
+ }
return supported;
}
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index ead4d1ce4a..c0bf18824e 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -138,28 +138,8 @@ public:
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
- const TensorInfo* inputLayerNormWeights = nullptr,
- const TensorInfo* forgetLayerNormWeights = nullptr,
- const TensorInfo* cellLayerNormWeights = nullptr,
- const TensorInfo* outputLayerNormWeights = nullptr) const override;
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& input1,