diff options
Diffstat (limited to 'src/backends/neon/workloads/NeonLstmFloatWorkload.cpp')
-rw-r--r-- | src/backends/neon/workloads/NeonLstmFloatWorkload.cpp | 148 |
1 files changed, 97 insertions, 51 deletions
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp index c7f5f090ce..6dd9f4f698 100644 --- a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp +++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp @@ -97,6 +97,30 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get()); } + if (m_Data.m_Parameters.m_LayerNormEnabled) + { + m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>(); + if (!m_Data.m_Parameters.m_CifgEnabled) + { + BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo()); + } + + m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo()); + + m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo()); + + m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo()); + + lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ? + nullptr : m_InputLayerNormWeightsTensor.get(), + m_ForgetLayerNormWeightsTensor.get(), + m_CellLayerNormWeightsTensor.get(), + m_OutputLayerNormWeightsTensor.get()); + } + const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); const arm_compute::ITensor& output_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); const arm_compute::ITensor& cell_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor(); @@ -113,13 +137,13 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript m_ScratchBuffer = std::make_unique<arm_compute::Tensor>(); if (m_Data.m_Parameters.m_CifgEnabled) { - // 2D tensor with dimensions [num_units * 4, batch_size] with CIFG + // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG armnn::TensorInfo scratchBuffer1({ batch_size, num_units * 3 }, DataType::Float32); BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer1); } else { - // scratch_buffer [num_units * 3, batch_size] without CIFG + // scratch_buffer [num_units * 4, batch_size] without CIFG armnn::TensorInfo scratchBuffer2({ batch_size, num_units * 4 }, DataType::Float32); BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer2); } @@ -222,6 +246,17 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript m_Data.m_CellToOutputWeights); } + if (m_Data.m_Parameters.m_LayerNormEnabled) + { + if (!m_Data.m_Parameters.m_CifgEnabled) + { + InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights); + } + InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights); + InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights); + InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights); + } + // Force Compute Library to perform the necessary copying and reshaping, after which // delete all the input tensors that will no longer be needed m_LstmLayer.prepare(); @@ -241,27 +276,11 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights) + const LstmInputParamsInfo& paramsInfo) { arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info; - // The inputs and the outputs + // The inputs and outputs const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input); const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn); const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn); @@ -271,18 +290,24 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); // Basic parameters - const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights); - const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights); - const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights); + const arm_compute::TensorInfo aclInputToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights()); + const arm_compute::TensorInfo aclInputToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights()); + const arm_compute::TensorInfo aclInputToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights()); const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo - = BuildArmComputeTensorInfo(recurrentToForgetWeights); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights()); const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo - = BuildArmComputeTensorInfo(recurrentToCellWeights); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights()); const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo - = BuildArmComputeTensorInfo(recurrentToOutputWeights); - const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias); - const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias); - const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights()); + const arm_compute::TensorInfo aclForgetGateBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias()); + const arm_compute::TensorInfo aclCellBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.get_CellBias()); + const arm_compute::TensorInfo aclOutputGateBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias()); arm_compute::TensorInfo aclInputToInputWeightsInfo; arm_compute::TensorInfo aclRecurrentToInputWeightsInfo; @@ -293,48 +318,65 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, arm_compute::TensorInfo aclCellToForgetWeightsInfo; arm_compute::TensorInfo aclCellToOutputWeightsInfo; + arm_compute::TensorInfo aclInputLayerNormWeightsInfo; + arm_compute::TensorInfo aclForgetLayerNormWeightsInfo; + arm_compute::TensorInfo aclCellLayerNormWeightsInfo; + arm_compute::TensorInfo aclOutputLayerNormWeightsInfo; + + if (!descriptor.m_CifgEnabled) { - armnn::TensorInfo inputToInputWInfo = *inputToInputWeights; - aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo); - armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights; - aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo); - - if (cellToInputWeights != nullptr) + if (descriptor.m_PeepholeEnabled) { - armnn::TensorInfo cellToInputWInfo = *cellToInputWeights; - aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo); + aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights()); } - armnn::TensorInfo inputGateBiasInfo = *inputGateBias; - aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo); + aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights()); + aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights()); + aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias()); + lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo, - cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr, + descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr, &aclInputGateBiasInfo); } if (descriptor.m_ProjectionEnabled) { - const armnn::TensorInfo& projectionWInfo = *projectionWeights; - aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo); - - if (projectionBias != nullptr) + if (paramsInfo.m_ProjectionBias != nullptr) { - const armnn::TensorInfo& projectionBiasInfo = *projectionBias; - aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo); + aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias()); } + aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights()); + lstm_params_info.set_projection_params(&aclProjectionWeightsInfo, - projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr); + paramsInfo.m_ProjectionBias != nullptr ? + &aclProjectionBiasInfo : nullptr); } if (descriptor.m_PeepholeEnabled) { - const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights; - aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo); - const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights; - aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo); + aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights()); + aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights()); + lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo); } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights()); + } + aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights()); + aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights()); + aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights()); + + lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ? + nullptr : &aclInputLayerNormWeightsInfo, + &aclForgetLayerNormWeightsInfo, + &aclCellLayerNormWeightsInfo, + &aclOutputLayerNormWeightsInfo); + } + float cell_threshold = descriptor.m_ClippingThresCell; float projection_threshold = descriptor.m_ClippingThresProj; @@ -407,6 +449,10 @@ void NeonLstmFloatWorkload::FreeUnusedTensors() FreeTensorIfUnused(m_ProjectionWeightsTensor); FreeTensorIfUnused(m_ProjectionBiasTensor); FreeTensorIfUnused(m_ScratchBuffer); + FreeTensorIfUnused(m_InputLayerNormWeightsTensor); + FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor); + FreeTensorIfUnused(m_CellLayerNormWeightsTensor); + FreeTensorIfUnused(m_OutputLayerNormWeightsTensor); } } //namespace armnn |