diff options
Diffstat (limited to 'src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp | 307 |
1 files changed, 307 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp new file mode 100644 index 0000000000..311fa18f91 --- /dev/null +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -0,0 +1,307 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefUnidirectionalSequenceLstmWorkload.hpp" +#include "Activation.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "Lstm.hpp" +#include "LstmUtils.hpp" +#include "RefWorkloadUtils.hpp" + +#include <armnnUtils/Permute.hpp> + +namespace armnn +{ + +RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload( + const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info) + , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) + , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) + , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) + , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) + , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) + , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) + , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) + , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) + , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) + , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) + , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) + , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) + , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) + , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) + , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) + , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) + , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) + , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) + , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) + , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) + , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) +{} + +void RefUnidirectionalSequenceLstmWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) +{ + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, + std::vector<ITensorHandle*> outputs) const +{ + TensorInfo inputInfo = GetTensorInfo(inputs[0]); + const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); + const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); + TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorShape& inputShape = inputInfo.GetShape(); + TensorShape& outputShape= outputInfo.GetShape(); + auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map()); + + if (!m_Data.m_Parameters.m_TimeMajor) + { + // Permute to time major + const PermutationVector& mappings = {1U, 0U, 2U}; + std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements()); + inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings); + inputInfo.SetShape(inputShape); + armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float)); + + outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); + outputInfo.SetShape(outputShape); + } + unsigned int maxTime = inputShape[0]; + unsigned int batchSize = inputShape[1]; + unsigned int outputSize = outputShape[2]; + unsigned int inputSize = inputShape[2]; + + TensorInfo scratchInfo = outputInfo; + scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]}); + + std::vector<float> inputGateScratchBuffer; + std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.); + std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.); + std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.); + + std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.); + std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.); + + void* outputStateOutData = outputStateOutBuffer.data(); + void* cellStateOutData = cellStateOutBuffer.data(); + + std::unique_ptr<Encoder<float>> inputGateScratch; + std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data()); + std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data()); + std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data()); + + std::unique_ptr<Decoder<float>> inputGateScratchDecoder; + std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data()); + std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo, + forgetGateScratchBuffer.data()); + std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo, + outputGateScratchBuffer.data()); + + const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; + const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; + const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; + + if (!useCifg) + { + inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.); + inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data()); + inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data()); + } + + std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData); + std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData); + std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData); + + TensorInfo lstmInputInfo = inputInfo; + TensorShape batchInputShape = TensorShape({batchSize, inputSize}); + lstmInputInfo.SetShape(batchInputShape); + + TensorInfo lstmOutputInfo = outputInfo; + lstmOutputInfo.SetShape({batchSize, outputSize}); + + const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); + const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); + unsigned int nOutput = recurrentToOutputWeightsShape[1]; + auto outputStateInData = inputs[1]->Map(); + std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData); + + auto cellStateInData = inputs[2]->Map(); + std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData); + + auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map()); + std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData); + auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map()); + std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData); + std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData); + + std::unique_ptr<Decoder<float>> inputToInputWeightsTensor; + std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>( + m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>( + m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>( + m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>()); + + std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor; + std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>( + m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>( + m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>( + m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>()); + + std::unique_ptr<Decoder<float>> inputGateBiasTensor; + std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>( + m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>( + m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>( + m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>()); + + std::unique_ptr<Decoder<float>> cellToInputWeightsTensor; + std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor; + std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor; + + std::unique_ptr<Decoder<float>> projectionWeightsTensor; + std::unique_ptr<Decoder<float>> projectionBiasTensor; + + std::unique_ptr<Decoder<float>> inputLayerNormWeights; + std::unique_ptr<Decoder<float>> forgetLayerNormWeights; + std::unique_ptr<Decoder<float>> cellLayerNormWeights; + std::unique_ptr<Decoder<float>> outputLayerNormWeights; + + if (useLayerNorm) + { + if (!useCifg) + { + inputLayerNormWeights = MakeDecoder<float>( + m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>()); + } + forgetLayerNormWeights = MakeDecoder<float>( + m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>()); + cellLayerNormWeights = MakeDecoder<float>( + m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>()); + outputLayerNormWeights = MakeDecoder<float>( + m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>()); + } + + if (!useCifg) + { + inputToInputWeightsTensor = MakeDecoder<float>( + m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>()); + inputGateBiasTensor = MakeDecoder<float>( + m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>()); + recurrentToInputWeightsTensor = MakeDecoder<float>( + m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>()); + } + + if (usePeephole) + { + cellToForgetWeightsTensor = MakeDecoder<float>( + m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>()); + cellToOutputWeightsTensor = MakeDecoder<float>( + m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>()); + } + + if (!useCifg && usePeephole) + { + cellToInputWeightsTensor = MakeDecoder<float>( + m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>()); + } + + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + projectionWeightsTensor = MakeDecoder<float>( + m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>()); + if (m_ProjectionBiasTensor) + { + projectionBiasTensor = MakeDecoder<float>( + m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>()); + } + } + + unsigned int batchInputSize = batchSize * inputSize; + unsigned int batchOutputSize = batchSize * nOutput; + + for (unsigned int t = 0; t < maxTime; ++t) + { + LstmImpl(m_Data.m_Parameters, + lstmInputInfo, + lstmOutputInfo, + inputToOutputWeightsShape, + recurrentToOutputWeightsShape, + inputData, + outputStateIn, + cellStateIn, + outputStateOut, + cellStateOut, + output, + cellStateOutDecoder, + outputDecoder, + inputToInputWeightsTensor, + inputToForgetWeightsTensor, + inputToCellWeightsTensor, + inputToOutputWeightsTensor, + recurrentToInputWeightsTensor, + recurrentToForgetWeightsTensor, + recurrentToCellWeightsTensor, + recurrentToOutputWeightsTensor, + cellToInputWeightsTensor, + cellToForgetWeightsTensor, + cellToOutputWeightsTensor, + inputGateBiasTensor, + forgetGateBiasTensor, + cellBiasTensor, + outputGateBiasTensor, + projectionWeightsTensor, + projectionBiasTensor, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights, + inputGateScratch, + cellScratch, + forgetGateScratch, + outputGateScratch, + inputGateScratchDecoder, + cellScratchDecoder, + forgetGateScratchDecoder, + outputGateScratchDecoder, + m_LayerNormEpsilon); + + currentInputData += batchInputSize; + inputData = MakeDecoder<float>(lstmInputInfo, currentInputData); + currentOutputData += batchOutputSize; + output = MakeEncoder<float>(lstmOutputInfo, currentOutputData); + outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData); + + // Assign output state out to the next output state in + outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData); + + // Assign cell state out to the next cell state in + cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData); + } + + if (!m_Data.m_Parameters.m_TimeMajor) + { + // Permute Output back to batch major + const PermutationVector& mappings = {1U, 0U, 2U}; + auto outputData = reinterpret_cast<float*>(outputs[0]->Map()); + std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements()); + outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); + outputInfo.SetShape(outputShape); + armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float)); + } +} + +} //namespace armnn |