author     Jan Eilers <jan.eilers@arm.com>          2019-07-08 15:56:59 +0100
committer  Nikhil Raj Arm <nikhil.raj@arm.com>      2019-07-09 15:08:10 +0000
commit     a2ec9092f0bff018bfe7ae0cacb7e30bcc17c1c7 (patch)
tree       11a36098492aed1629d873dcc02fb25a0071de2b
parent     c0ed7baa8c05c4710034dfd179fadd31b716a46f (diff)
download   armnn-a2ec9092f0bff018bfe7ae0cacb7e30bcc17c1c7.tar.gz
IVGCVSW-3338 Add CL backend support for LSTM normalization
* Enable calls to LSTM normalization unit tests on CL backend.
* Update CL workload to set the layer normalization parameters.

!android-nn-driver:1461

Change-Id: Ia5a29918961c391c1f1d8f331add377a38822ddd
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
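For context, a layer-norm LSTM carries one 1-D scale vector per normalized gate (input, forget, cell and output; the input vector is omitted when CIFG is enabled), and this patch forwards those vectors from the workload descriptor to arm_compute::CLLSTMLayer. The sketch below is a hypothetical illustration of how a caller might request this through the Arm NN front-end; the header paths, the activation code and both helper names are assumptions for illustration only and are not part of this patch.

// Hypothetical sketch (not taken from this patch): filling in the descriptor
// and input-parameter fields that drive the layer-normalization path.
#include <armnn/Descriptors.hpp>   // assumed home of armnn::LstmDescriptor
#include <armnn/LstmParams.hpp>    // assumed home of armnn::LstmInputParams
#include <armnn/Tensor.hpp>

armnn::LstmDescriptor MakeLayerNormLstmDescriptor()
{
    armnn::LstmDescriptor desc;
    desc.m_ActivationFunc    = 4;     // 4 == TanH in the NNAPI-style encoding the workload switches on
    desc.m_CifgEnabled       = false; // with CIFG enabled, no input-gate layer-norm weights are passed
    desc.m_PeepholeEnabled   = true;
    desc.m_ProjectionEnabled = true;
    desc.m_LayerNormEnabled  = true;  // the flag this patch plumbs through to the CL workload
    return desc;
}

void SetLayerNormWeights(armnn::LstmInputParams& params,
                         const armnn::ConstTensor* inputLn,   // pass nullptr when CIFG is enabled
                         const armnn::ConstTensor* forgetLn,
                         const armnn::ConstTensor* cellLn,
                         const armnn::ConstTensor* outputLn)
{
    // Each pointer refers to a 1-D FP32 tensor of shape [num_units],
    // one scale vector per normalized gate.
    params.m_InputLayerNormWeights  = inputLn;
    params.m_ForgetLayerNormWeights = forgetLn;
    params.m_CellLayerNormWeights   = cellLn;
    params.m_OutputLayerNormWeights = outputLn;
}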
-rw-r--r--  src/backends/cl/test/ClLayerTests.cpp              |  3
-rw-r--r--  src/backends/cl/workloads/ClLstmFloatWorkload.cpp  | 77
-rw-r--r--  src/backends/cl/workloads/ClLstmFloatWorkload.hpp  |  4
3 files changed, 76 insertions(+), 8 deletions(-)
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index ac96bf8135..5575a05b99 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -354,6 +354,9 @@ ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgNoPeepholeNoProjection,
ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection,
LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest)
+ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm,
+ LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest)
+
// Convert from Float16 to Float32
ARMNN_AUTO_TEST_CASE(SimpleConvertFp16ToFp32, SimpleConvertFp16ToFp32Test)
// Convert from Float32 to Float16
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
index 3dbbbc3784..f5d081e778 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
@@ -100,6 +100,28 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor,
lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
}
+ if (m_Data.m_Parameters.m_LayerNormEnabled)
+ {
+ m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+ m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+ m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+ m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+
+ if (!m_Data.m_Parameters.m_CifgEnabled)
+ {
+ BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
+ }
+ BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
+ BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
+ BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
+
+ lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ? nullptr :
+ m_InputLayerNormWeightsTensor.get(),
+ m_ForgetLayerNormWeightsTensor.get(),
+ m_CellLayerNormWeightsTensor.get(),
+ m_OutputLayerNormWeightsTensor.get());
+ }
+
const arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
const arm_compute::ICLTensor& output_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
const arm_compute::ICLTensor& cell_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
@@ -161,7 +183,6 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor,
throw armnn::Exception("Wrong Type of Activation Function!");
}
-
m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(),
m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(),
m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(),
@@ -172,15 +193,15 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor,
armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
- InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
- InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
- InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
+ InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
+ InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
+ InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
- InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
- InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
- InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
- InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
+ InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
+ InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
+ InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
if (!m_Data.m_Parameters.m_CifgEnabled)
{
@@ -208,6 +229,18 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor,
InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
}
+ if (m_Data.m_Parameters.m_LayerNormEnabled)
+ {
+ if (!m_Data.m_Parameters.m_CifgEnabled)
+ {
+ InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
+ }
+
+ InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
+ InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
+ InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
+ }
+
// Force Compute Library to perform the necessary copying and reshaping, after which
// delete all the input tensors that will no longer be needed
m_LstmLayer.prepare();
@@ -262,6 +295,10 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
arm_compute::TensorInfo aclProjectionBiasInfo;
arm_compute::TensorInfo aclCellToForgetWeightsInfo;
arm_compute::TensorInfo aclCellToOutputWeightsInfo;
+ arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
if (!descriptor.m_CifgEnabled)
{
@@ -333,6 +370,26 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
throw armnn::Exception("Wrong Type of Activation Function!");
}
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+ }
+
+ aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
+
+ aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
+
+ aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+
+ lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
+ nullptr : &aclInputLayerNormWeightsInfo,
+ &aclForgetLayerNormWeightsInfo,
+ &aclCellLayerNormWeightsInfo,
+ &aclOutputLayerNormWeightsInfo);
+ }
+
return arm_compute::CLLSTMLayer::validate(&aclInputInfo, &aclInputToForgetWeightsInfo,
&aclInputToCellWeightsInfo,
&aclInputToOutputWeightsInfo,
@@ -369,6 +426,10 @@ void ClLstmFloatWorkload::FreeUnusedTensors()
FreeTensorIfUnused(m_ProjectionWeightsTensor);
FreeTensorIfUnused(m_ProjectionBiasTensor);
FreeTensorIfUnused(m_ScratchBuffer);
+ FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
}
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
index 9a3211a037..5bd67c256f 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
@@ -39,6 +39,10 @@ private:
std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor;
std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor;
+ std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor;
std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer;