// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include "RefLstmWorkload.hpp" #include "Activation.hpp" #include "Encoders.hpp" #include "Decoders.hpp" #include "Lstm.hpp" #include "LstmUtils.hpp" #include "RefWorkloadUtils.hpp" namespace armnn { RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) : BaseWorkload(descriptor, info) , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) {} void RefLstmWorkload::Execute() const { Execute(m_Data.m_Inputs, m_Data.m_Outputs); } void RefLstmWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) { Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); } void RefLstmWorkload::Execute(std::vector inputs, std::vector outputs) const { // This is a porting of the LSTM::Eval() method in the Android code base // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); const TensorShape& inputShape = inputInfo.GetShape(); std::unique_ptr> outputStateOut = MakeEncoder(outputInfo, outputs[1]->Map()); std::unique_ptr> cellStateOut = MakeEncoder(outputInfo, outputs[2]->Map()); std::unique_ptr> output = MakeEncoder(outputInfo, outputs[3]->Map()); std::unique_ptr> cellStateOutDecoder = MakeDecoder(outputInfo, outputs[2]->Map()); std::unique_ptr> outputDecoder = MakeDecoder(outputInfo, outputs[3]->Map()); std::unique_ptr> inputData = MakeDecoder(inputInfo, inputs[0]->Map()); std::unique_ptr> outputStateIn = MakeDecoder(inputInfo, inputs[1]->Map()); std::unique_ptr> cellStateIn = MakeDecoder(inputInfo, inputs[2]->Map()); const uint32_t nBatch = inputShape[0]; const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; // Index the scratch buffers pointers to the global scratch buffer. std::unique_ptr> inputGateScratch = MakeEncoder(outputInfo, outputs[0]->Map()); std::unique_ptr> cellScratch = MakeEncoder(outputInfo, outputs[0]->Map()); std::unique_ptr> forgetGateScratch = MakeEncoder(outputInfo, outputs[0]->Map()); std::unique_ptr> outputGateScratch = MakeEncoder(outputInfo, outputs[0]->Map()); std::unique_ptr> inputGateScratchDecoder = MakeDecoder(outputInfo, outputs[0]->Map()); std::unique_ptr> cellScratchDecoder = MakeDecoder(outputInfo, outputs[0]->Map()); std::unique_ptr> forgetGateScratchDecoder = MakeDecoder(outputInfo, outputs[0]->Map()); std::unique_ptr> outputGateScratchDecoder = MakeDecoder(outputInfo, outputs[0]->Map()); if (useCifg) { *cellScratch += (0 * nCell * nBatch); *forgetGateScratch += (1 * nCell * nBatch); *outputGateScratch += (2 * nCell * nBatch); *cellScratchDecoder += (0 * nCell * nBatch); *forgetGateScratchDecoder += (1 * nCell * nBatch); *outputGateScratchDecoder += (2 * nCell * nBatch); } else { *inputGateScratch += (0 * nCell * nBatch); *cellScratch += (1 * nCell * nBatch); *forgetGateScratch += (2 * nCell * nBatch); *outputGateScratch += (3 * nCell * nBatch); *inputGateScratchDecoder += (0 * nCell * nBatch); *cellScratchDecoder += (1 * nCell * nBatch); *forgetGateScratchDecoder += (2 * nCell * nBatch); *outputGateScratchDecoder += (3 * nCell * nBatch); } std::unique_ptr> inputToInputWeightsTensor; std::unique_ptr> inputToForgetWeightsTensor = MakeDecoder( m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor()); std::unique_ptr> inputToCellWeightsTensor = MakeDecoder( m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor()); std::unique_ptr> inputToOutputWeightsTensor = MakeDecoder( m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor()); std::unique_ptr> recurrentToInputWeightsTensor; std::unique_ptr> recurrentToForgetWeightsTensor = MakeDecoder( m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor()); std::unique_ptr> recurrentToCellWeightsTensor = MakeDecoder( m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor()); std::unique_ptr> recurrentToOutputWeightsTensor = MakeDecoder( m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor()); std::unique_ptr> inputGateBiasTensor; std::unique_ptr> forgetGateBiasTensor = MakeDecoder( m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor()); std::unique_ptr> cellBiasTensor = MakeDecoder( m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor()); std::unique_ptr> outputGateBiasTensor = MakeDecoder( m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor()); std::unique_ptr> cellToInputWeightsTensor; std::unique_ptr> cellToForgetWeightsTensor; std::unique_ptr> cellToOutputWeightsTensor; std::unique_ptr> projectionWeightsTensor; std::unique_ptr> projectionBiasTensor; std::unique_ptr> inputLayerNormWeights; std::unique_ptr> forgetLayerNormWeights; std::unique_ptr> cellLayerNormWeights; std::unique_ptr> outputLayerNormWeights; const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); if (useLayerNorm) { if (!useCifg) { inputLayerNormWeights = MakeDecoder( m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor()); } forgetLayerNormWeights = MakeDecoder( m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor()); cellLayerNormWeights = MakeDecoder( m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor()); outputLayerNormWeights = MakeDecoder( m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor()); } if (!useCifg) { inputToInputWeightsTensor = MakeDecoder( m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor()); inputGateBiasTensor = MakeDecoder( m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor()); recurrentToInputWeightsTensor = MakeDecoder( m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor()); } if (usePeephole) { cellToForgetWeightsTensor = MakeDecoder( m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor()); cellToOutputWeightsTensor = MakeDecoder( m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor()); } if (!useCifg && usePeephole) { cellToInputWeightsTensor = MakeDecoder( m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor()); } if (m_Data.m_Parameters.m_ProjectionEnabled) { projectionWeightsTensor = MakeDecoder( m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor()); if (m_ProjectionBiasTensor) { projectionBiasTensor = MakeDecoder( m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor()); } } LstmImpl(m_Data.m_Parameters, inputInfo, outputInfo, inputToOutputWeightsShape, recurrentToOutputWeightsShape, inputData, outputStateIn, cellStateIn, outputStateOut, cellStateOut, output, cellStateOutDecoder, outputDecoder, inputToInputWeightsTensor, inputToForgetWeightsTensor, inputToCellWeightsTensor, inputToOutputWeightsTensor, recurrentToInputWeightsTensor, recurrentToForgetWeightsTensor, recurrentToCellWeightsTensor, recurrentToOutputWeightsTensor, cellToInputWeightsTensor, cellToForgetWeightsTensor, cellToOutputWeightsTensor, inputGateBiasTensor, forgetGateBiasTensor, cellBiasTensor, outputGateBiasTensor, projectionWeightsTensor, projectionBiasTensor, inputLayerNormWeights, forgetLayerNormWeights, cellLayerNormWeights, outputLayerNormWeights, inputGateScratch, cellScratch, forgetGateScratch, outputGateScratch, inputGateScratchDecoder, cellScratchDecoder, forgetGateScratchDecoder, outputGateScratchDecoder, m_LayerNormEpsilon); } } //namespace armnn