about summary refs log tree commit diff
path: root/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-08 09:57:55 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-10 09:15:04 +0000
commitad5293a86e315049de36afd723dcd1a7e70681a7 (patch)
treeb9003cd1fba00c267a971d899284b3fcbd5ce6f5 /src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
parent8b797a84f1e8f9d1d5d064afbc4fc12c21b8ffed (diff)
downloadarmnn-ad5293a86e315049de36afd723dcd1a7e70681a7.tar.gz
IVGCVSW-3337 Add Neon backend support for LSTM layer normalisation
* Update neon lstm workload
* Add unit tests
* Add isLstmSupported

Change-Id: I493c159137f6544b0f2532d16d4fafd7a7e587e5
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/neon/workloads/NeonLstmFloatWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonLstmFloatWorkload.cpp148
1 file changed, 97 insertions, 51 deletions
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
index c7f5f090ce..6dd9f4f698 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
@@ -97,6 +97,30 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
}
+ if (m_Data.m_Parameters.m_LayerNormEnabled)
+ {
+ m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ if (!m_Data.m_Parameters.m_CifgEnabled)
+ {
+ BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
+ }
+
+ m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
+
+ m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
+
+ m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
+
+ lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ?
+ nullptr : m_InputLayerNormWeightsTensor.get(),
+ m_ForgetLayerNormWeightsTensor.get(),
+ m_CellLayerNormWeightsTensor.get(),
+ m_OutputLayerNormWeightsTensor.get());
+ }
+
const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
const arm_compute::ITensor& output_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
const arm_compute::ITensor& cell_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
@@ -113,13 +137,13 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
m_ScratchBuffer = std::make_unique<arm_compute::Tensor>();
if (m_Data.m_Parameters.m_CifgEnabled)
{
- // 2D tensor with dimensions [num_units * 4, batch_size] with CIFG
+ // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG
armnn::TensorInfo scratchBuffer1({ batch_size, num_units * 3 }, DataType::Float32);
BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer1);
}
else
{
- // scratch_buffer [num_units * 3, batch_size] without CIFG
+ // scratch_buffer [num_units * 4, batch_size] without CIFG
armnn::TensorInfo scratchBuffer2({ batch_size, num_units * 4 }, DataType::Float32);
BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer2);
}
@@ -222,6 +246,17 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
m_Data.m_CellToOutputWeights);
}
+ if (m_Data.m_Parameters.m_LayerNormEnabled)
+ {
+ if (!m_Data.m_Parameters.m_CifgEnabled)
+ {
+ InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
+ }
+ InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
+ InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
+ InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
+ }
+
// Force Compute Library to perform the necessary copying and reshaping, after which
// delete all the input tensors that will no longer be needed
m_LstmLayer.prepare();
@@ -241,27 +276,11 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights)
+ const LstmInputParamsInfo& paramsInfo)
{
arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
- // The inputs and the outputs
+ // The inputs and outputs
const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
@@ -271,18 +290,24 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
// Basic parameters
- const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
- const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
- const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
+ const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+ const arm_compute::TensorInfo aclInputToCellWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+ const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToForgetWeights);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToCellWeights);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToOutputWeights);
- const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
- const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
- const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+ const arm_compute::TensorInfo aclForgetGateBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+ const arm_compute::TensorInfo aclCellBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+ const arm_compute::TensorInfo aclOutputGateBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
arm_compute::TensorInfo aclInputToInputWeightsInfo;
arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -293,48 +318,65 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
arm_compute::TensorInfo aclCellToForgetWeightsInfo;
arm_compute::TensorInfo aclCellToOutputWeightsInfo;
+ arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
+
+
if (!descriptor.m_CifgEnabled)
{
- armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
- aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
- armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
- aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
-
- if (cellToInputWeights != nullptr)
+ if (descriptor.m_PeepholeEnabled)
{
- armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
- aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
+ aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
}
- armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
- aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
+ aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+ aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
+ aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+
lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
- cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
+ descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr,
&aclInputGateBiasInfo);
}
if (descriptor.m_ProjectionEnabled)
{
- const armnn::TensorInfo& projectionWInfo = *projectionWeights;
- aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
-
- if (projectionBias != nullptr)
+ if (paramsInfo.m_ProjectionBias != nullptr)
{
- const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
- aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
+ aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias());
}
+ aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
+
lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
- projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
+ paramsInfo.m_ProjectionBias != nullptr ?
+ &aclProjectionBiasInfo : nullptr);
}
if (descriptor.m_PeepholeEnabled)
{
- const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
- aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
- const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
- aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
+ aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
+ aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
+
lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
}
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+ }
+ aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
+ aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
+ aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+
+ lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
+ nullptr : &aclInputLayerNormWeightsInfo,
+ &aclForgetLayerNormWeightsInfo,
+ &aclCellLayerNormWeightsInfo,
+ &aclOutputLayerNormWeightsInfo);
+ }
+
float cell_threshold = descriptor.m_ClippingThresCell;
float projection_threshold = descriptor.m_ClippingThresProj;
@@ -407,6 +449,10 @@ void NeonLstmFloatWorkload::FreeUnusedTensors()
FreeTensorIfUnused(m_ProjectionWeightsTensor);
FreeTensorIfUnused(m_ProjectionBiasTensor);
FreeTensorIfUnused(m_ScratchBuffer);
+ FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
}
} //namespace armnn