// // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "RefUnidirectionalSequenceLstmWorkload.hpp" #include "Activation.hpp" #include "Encoders.hpp" #include "Decoders.hpp" #include "Lstm.hpp" #include "LstmUtils.hpp" #include "RefWorkloadUtils.hpp" #include namespace armnn { RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload( const UnidirectionalSequenceLstmQueueDescriptor& descriptor, const WorkloadInfo& info) : BaseWorkload(descriptor, info) , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) {} void RefUnidirectionalSequenceLstmWorkload::Execute() const { Execute(m_Data.m_Inputs, m_Data.m_Outputs); } void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) { Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); } void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector inputs, std::vector outputs) const { TensorInfo inputInfo = GetTensorInfo(inputs[0]); const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); TensorInfo outputInfo = GetTensorInfo(outputs[0]); TensorShape& inputShape = inputInfo.GetShape(); TensorShape& outputShape= outputInfo.GetShape(); auto inputTensor = reinterpret_cast(inputs[0]->Map()); if (!m_Data.m_Parameters.m_TimeMajor) { // Permute to time major const PermutationVector& mappings = {1U, 0U, 2U}; std::vector inputValue(inputTensor, inputTensor + inputInfo.GetNumElements()); inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings); inputInfo.SetShape(inputShape); armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float)); outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); outputInfo.SetShape(outputShape); } unsigned int maxTime = inputShape[0]; unsigned int batchSize = inputShape[1]; unsigned int outputSize = outputShape[2]; unsigned int inputSize = inputShape[2]; TensorInfo scratchInfo = outputInfo; scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]}); std::vector inputGateScratchBuffer; std::vector cellScratchBuffer(scratchInfo.GetNumElements(), 0.); std::vector forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.); std::vector outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.); std::vector outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.); std::vector cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.); void* outputStateOutData = outputStateOutBuffer.data(); void* cellStateOutData = cellStateOutBuffer.data(); std::unique_ptr> inputGateScratch; std::unique_ptr> cellScratch = MakeEncoder(scratchInfo, cellScratchBuffer.data()); std::unique_ptr> forgetGateScratch = MakeEncoder(scratchInfo, forgetGateScratchBuffer.data()); std::unique_ptr> outputGateScratch = MakeEncoder(scratchInfo, outputGateScratchBuffer.data()); std::unique_ptr> inputGateScratchDecoder; std::unique_ptr> cellScratchDecoder = MakeDecoder(scratchInfo, cellScratchBuffer.data()); std::unique_ptr> forgetGateScratchDecoder = MakeDecoder(scratchInfo, forgetGateScratchBuffer.data()); std::unique_ptr> outputGateScratchDecoder = MakeDecoder(scratchInfo, outputGateScratchBuffer.data()); const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; if (!useCifg) { inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.); inputGateScratch = MakeEncoder(scratchInfo, inputGateScratchBuffer.data()); inputGateScratchDecoder = MakeDecoder(scratchInfo, inputGateScratchBuffer.data()); } std::unique_ptr> outputStateOut = MakeEncoder(outputStateInfo, outputStateOutData); std::unique_ptr> cellStateOut = MakeEncoder(cellStateInfo, cellStateOutData); std::unique_ptr> cellStateOutDecoder = MakeDecoder(cellStateInfo, cellStateOutData); TensorInfo lstmInputInfo = inputInfo; TensorShape batchInputShape = TensorShape({batchSize, inputSize}); lstmInputInfo.SetShape(batchInputShape); TensorInfo lstmOutputInfo = outputInfo; lstmOutputInfo.SetShape({batchSize, outputSize}); const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); unsigned int nOutput = recurrentToOutputWeightsShape[1]; auto outputStateInData = inputs[1]->Map(); std::unique_ptr> outputStateIn = MakeDecoder(outputStateInfo, outputStateInData); auto cellStateInData = inputs[2]->Map(); std::unique_ptr> cellStateIn = MakeDecoder(cellStateInfo, cellStateInData); auto currentInputData = reinterpret_cast(inputs[0]->Map()); std::unique_ptr> inputData = MakeDecoder(lstmInputInfo, currentInputData); auto currentOutputData = reinterpret_cast(outputs[0]->Map()); std::unique_ptr> output = MakeEncoder(lstmOutputInfo, currentOutputData); std::unique_ptr> outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData); std::unique_ptr> inputToInputWeightsTensor; std::unique_ptr> inputToForgetWeightsTensor = MakeDecoder( m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor()); std::unique_ptr> inputToCellWeightsTensor = MakeDecoder( m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor()); std::unique_ptr> inputToOutputWeightsTensor = MakeDecoder( m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor()); std::unique_ptr> recurrentToInputWeightsTensor; std::unique_ptr> recurrentToForgetWeightsTensor = MakeDecoder( m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor()); std::unique_ptr> recurrentToCellWeightsTensor = MakeDecoder( m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor()); std::unique_ptr> recurrentToOutputWeightsTensor = MakeDecoder( m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor()); std::unique_ptr> inputGateBiasTensor; std::unique_ptr> forgetGateBiasTensor = MakeDecoder( m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor()); std::unique_ptr> cellBiasTensor = MakeDecoder( m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor()); std::unique_ptr> outputGateBiasTensor = MakeDecoder( m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor()); std::unique_ptr> cellToInputWeightsTensor; std::unique_ptr> cellToForgetWeightsTensor; std::unique_ptr> cellToOutputWeightsTensor; std::unique_ptr> projectionWeightsTensor; std::unique_ptr> projectionBiasTensor; std::unique_ptr> inputLayerNormWeights; std::unique_ptr> forgetLayerNormWeights; std::unique_ptr> cellLayerNormWeights; std::unique_ptr> outputLayerNormWeights; if (useLayerNorm) { if (!useCifg) { inputLayerNormWeights = MakeDecoder( m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor()); } forgetLayerNormWeights = MakeDecoder( m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor()); cellLayerNormWeights = MakeDecoder( m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor()); outputLayerNormWeights = MakeDecoder( m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor()); } if (!useCifg) { inputToInputWeightsTensor = MakeDecoder( m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor()); inputGateBiasTensor = MakeDecoder( m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor()); recurrentToInputWeightsTensor = MakeDecoder( m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor()); } if (usePeephole) { cellToForgetWeightsTensor = MakeDecoder( m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor()); cellToOutputWeightsTensor = MakeDecoder( m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor()); } if (!useCifg && usePeephole) { cellToInputWeightsTensor = MakeDecoder( m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor()); } if (m_Data.m_Parameters.m_ProjectionEnabled) { projectionWeightsTensor = MakeDecoder( m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor()); if (m_ProjectionBiasTensor) { projectionBiasTensor = MakeDecoder( m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor()); } } unsigned int batchInputSize = batchSize * inputSize; unsigned int batchOutputSize = batchSize * nOutput; for (unsigned int t = 0; t < maxTime; ++t) { LstmImpl(m_Data.m_Parameters, lstmInputInfo, lstmOutputInfo, inputToOutputWeightsShape, recurrentToOutputWeightsShape, inputData, outputStateIn, cellStateIn, outputStateOut, cellStateOut, output, cellStateOutDecoder, outputDecoder, inputToInputWeightsTensor, inputToForgetWeightsTensor, inputToCellWeightsTensor, inputToOutputWeightsTensor, recurrentToInputWeightsTensor, recurrentToForgetWeightsTensor, recurrentToCellWeightsTensor, recurrentToOutputWeightsTensor, cellToInputWeightsTensor, cellToForgetWeightsTensor, cellToOutputWeightsTensor, inputGateBiasTensor, forgetGateBiasTensor, cellBiasTensor, outputGateBiasTensor, projectionWeightsTensor, projectionBiasTensor, inputLayerNormWeights, forgetLayerNormWeights, cellLayerNormWeights, outputLayerNormWeights, inputGateScratch, cellScratch, forgetGateScratch, outputGateScratch, inputGateScratchDecoder, cellScratchDecoder, forgetGateScratchDecoder, outputGateScratchDecoder, m_LayerNormEpsilon); currentInputData += batchInputSize; inputData = MakeDecoder(lstmInputInfo, currentInputData); currentOutputData += batchOutputSize; output = MakeEncoder(lstmOutputInfo, currentOutputData); outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData); // Assign output state out to the next output state in outputStateIn = MakeDecoder(outputStateInfo, outputStateOutData); // Assign cell state out to the next cell state in cellStateIn = MakeDecoder(cellStateInfo, cellStateOutData); } if (!m_Data.m_Parameters.m_TimeMajor) { // Permute Output back to batch major const PermutationVector& mappings = {1U, 0U, 2U}; auto outputData = reinterpret_cast(outputs[0]->Map()); std::vector outputValue(outputData, outputData + outputInfo.GetNumElements()); outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); outputInfo.SetShape(outputShape); armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float)); } } } //namespace armnn