From e5339e7013cf24e5a34509fb0a60377e5f8a244e Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Wed, 28 Jul 2021 17:33:28 +0100 Subject: MLCE-530 Add support for UnidirectionalSequenceLstm to RefWorkload * Add implementation of IsUnidirectionalSequenceLstmSupported to RefLayerSupport * Add RefUnidirectionalSequenceLstmWorkload * Refactor Lstm to be able to use for Lstm and SequenceLstm * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Ibc066d213213a11b955dfefbe518de643298ba0c --- .../RefUnidirectionalSequenceLstmWorkload.cpp | 307 +++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp (limited to 'src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp') diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp new file mode 100644 index 0000000000..311fa18f91 --- /dev/null +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -0,0 +1,307 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefUnidirectionalSequenceLstmWorkload.hpp" +#include "Activation.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "Lstm.hpp" +#include "LstmUtils.hpp" +#include "RefWorkloadUtils.hpp" + +#include + +namespace armnn +{ + +RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload( + const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) + , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) + , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) + , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) + , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) + , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) + , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) + , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) + , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) + , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) + , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) + , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) + , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) + , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) + , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) + , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) + , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) + , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) + , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) + , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) + , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) + , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) +{} + +void RefUnidirectionalSequenceLstmWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) +{ + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector inputs, + std::vector outputs) const +{ + TensorInfo inputInfo = GetTensorInfo(inputs[0]); + const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); + const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); + TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorShape& inputShape = inputInfo.GetShape(); + TensorShape& outputShape= outputInfo.GetShape(); + auto inputTensor = reinterpret_cast(inputs[0]->Map()); + + if (!m_Data.m_Parameters.m_TimeMajor) + { + // Permute to time major + const PermutationVector& mappings = {1U, 0U, 2U}; + std::vector inputValue(inputTensor, inputTensor + inputInfo.GetNumElements()); + inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings); + inputInfo.SetShape(inputShape); + armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float)); + + outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); + outputInfo.SetShape(outputShape); + } + unsigned int maxTime = inputShape[0]; + unsigned int batchSize = inputShape[1]; + unsigned int outputSize = outputShape[2]; + unsigned int inputSize = inputShape[2]; + + TensorInfo scratchInfo = outputInfo; + scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]}); + + std::vector inputGateScratchBuffer; + std::vector cellScratchBuffer(scratchInfo.GetNumElements(), 0.); + std::vector forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.); + std::vector outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.); + + std::vector outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.); + std::vector cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.); + + void* outputStateOutData = outputStateOutBuffer.data(); + void* cellStateOutData = cellStateOutBuffer.data(); + + std::unique_ptr> inputGateScratch; + std::unique_ptr> cellScratch = MakeEncoder(scratchInfo, cellScratchBuffer.data()); + std::unique_ptr> forgetGateScratch = MakeEncoder(scratchInfo, forgetGateScratchBuffer.data()); + std::unique_ptr> outputGateScratch = MakeEncoder(scratchInfo, outputGateScratchBuffer.data()); + + std::unique_ptr> inputGateScratchDecoder; + std::unique_ptr> cellScratchDecoder = MakeDecoder(scratchInfo, cellScratchBuffer.data()); + std::unique_ptr> forgetGateScratchDecoder = MakeDecoder(scratchInfo, + forgetGateScratchBuffer.data()); + std::unique_ptr> outputGateScratchDecoder = MakeDecoder(scratchInfo, + outputGateScratchBuffer.data()); + + const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; + const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; + const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; + + if (!useCifg) + { + inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.); + inputGateScratch = MakeEncoder(scratchInfo, inputGateScratchBuffer.data()); + inputGateScratchDecoder = MakeDecoder(scratchInfo, inputGateScratchBuffer.data()); + } + + std::unique_ptr> outputStateOut = MakeEncoder(outputStateInfo, outputStateOutData); + std::unique_ptr> cellStateOut = MakeEncoder(cellStateInfo, cellStateOutData); + std::unique_ptr> cellStateOutDecoder = MakeDecoder(cellStateInfo, cellStateOutData); + + TensorInfo lstmInputInfo = inputInfo; + TensorShape batchInputShape = TensorShape({batchSize, inputSize}); + lstmInputInfo.SetShape(batchInputShape); + + TensorInfo lstmOutputInfo = outputInfo; + lstmOutputInfo.SetShape({batchSize, outputSize}); + + const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); + const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); + unsigned int nOutput = recurrentToOutputWeightsShape[1]; + auto outputStateInData = inputs[1]->Map(); + std::unique_ptr> outputStateIn = MakeDecoder(outputStateInfo, outputStateInData); + + auto cellStateInData = inputs[2]->Map(); + std::unique_ptr> cellStateIn = MakeDecoder(cellStateInfo, cellStateInData); + + auto currentInputData = reinterpret_cast(inputs[0]->Map()); + std::unique_ptr> inputData = MakeDecoder(lstmInputInfo, currentInputData); + auto currentOutputData = reinterpret_cast(outputs[0]->Map()); + std::unique_ptr> output = MakeEncoder(lstmOutputInfo, currentOutputData); + std::unique_ptr> outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData); + + std::unique_ptr> inputToInputWeightsTensor; + std::unique_ptr> inputToForgetWeightsTensor = MakeDecoder( + m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor()); + std::unique_ptr> inputToCellWeightsTensor = MakeDecoder( + m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor()); + std::unique_ptr> inputToOutputWeightsTensor = MakeDecoder( + m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor()); + + std::unique_ptr> recurrentToInputWeightsTensor; + std::unique_ptr> recurrentToForgetWeightsTensor = MakeDecoder( + m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor()); + std::unique_ptr> recurrentToCellWeightsTensor = MakeDecoder( + m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor()); + std::unique_ptr> recurrentToOutputWeightsTensor = MakeDecoder( + m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor()); + + std::unique_ptr> inputGateBiasTensor; + std::unique_ptr> forgetGateBiasTensor = MakeDecoder( + m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor()); + std::unique_ptr> cellBiasTensor = MakeDecoder( + m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor()); + std::unique_ptr> outputGateBiasTensor = MakeDecoder( + m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor()); + + std::unique_ptr> cellToInputWeightsTensor; + std::unique_ptr> cellToForgetWeightsTensor; + std::unique_ptr> cellToOutputWeightsTensor; + + std::unique_ptr> projectionWeightsTensor; + std::unique_ptr> projectionBiasTensor; + + std::unique_ptr> inputLayerNormWeights; + std::unique_ptr> forgetLayerNormWeights; + std::unique_ptr> cellLayerNormWeights; + std::unique_ptr> outputLayerNormWeights; + + if (useLayerNorm) + { + if (!useCifg) + { + inputLayerNormWeights = MakeDecoder( + m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor()); + } + forgetLayerNormWeights = MakeDecoder( + m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor()); + cellLayerNormWeights = MakeDecoder( + m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor()); + outputLayerNormWeights = MakeDecoder( + m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor()); + } + + if (!useCifg) + { + inputToInputWeightsTensor = MakeDecoder( + m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor()); + inputGateBiasTensor = MakeDecoder( + m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor()); + recurrentToInputWeightsTensor = MakeDecoder( + m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor()); + } + + if (usePeephole) + { + cellToForgetWeightsTensor = MakeDecoder( + m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor()); + cellToOutputWeightsTensor = MakeDecoder( + m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor()); + } + + if (!useCifg && usePeephole) + { + cellToInputWeightsTensor = MakeDecoder( + m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor()); + } + + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + projectionWeightsTensor = MakeDecoder( + m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor()); + if (m_ProjectionBiasTensor) + { + projectionBiasTensor = MakeDecoder( + m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor()); + } + } + + unsigned int batchInputSize = batchSize * inputSize; + unsigned int batchOutputSize = batchSize * nOutput; + + for (unsigned int t = 0; t < maxTime; ++t) + { + LstmImpl(m_Data.m_Parameters, + lstmInputInfo, + lstmOutputInfo, + inputToOutputWeightsShape, + recurrentToOutputWeightsShape, + inputData, + outputStateIn, + cellStateIn, + outputStateOut, + cellStateOut, + output, + cellStateOutDecoder, + outputDecoder, + inputToInputWeightsTensor, + inputToForgetWeightsTensor, + inputToCellWeightsTensor, + inputToOutputWeightsTensor, + recurrentToInputWeightsTensor, + recurrentToForgetWeightsTensor, + recurrentToCellWeightsTensor, + recurrentToOutputWeightsTensor, + cellToInputWeightsTensor, + cellToForgetWeightsTensor, + cellToOutputWeightsTensor, + inputGateBiasTensor, + forgetGateBiasTensor, + cellBiasTensor, + outputGateBiasTensor, + projectionWeightsTensor, + projectionBiasTensor, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights, + inputGateScratch, + cellScratch, + forgetGateScratch, + outputGateScratch, + inputGateScratchDecoder, + cellScratchDecoder, + forgetGateScratchDecoder, + outputGateScratchDecoder, + m_LayerNormEpsilon); + + currentInputData += batchInputSize; + inputData = MakeDecoder(lstmInputInfo, currentInputData); + currentOutputData += batchOutputSize; + output = MakeEncoder(lstmOutputInfo, currentOutputData); + outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData); + + // Assign output state out to the next output state in + outputStateIn = MakeDecoder(outputStateInfo, outputStateOutData); + + // Assign cell state out to the next cell state in + cellStateIn = MakeDecoder(cellStateInfo, cellStateOutData); + } + + if (!m_Data.m_Parameters.m_TimeMajor) + { + // Permute Output back to batch major + const PermutationVector& mappings = {1U, 0U, 2U}; + auto outputData = reinterpret_cast(outputs[0]->Map()); + std::vector outputValue(outputData, outputData + outputInfo.GetNumElements()); + outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); + outputInfo.SetShape(outputShape); + armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float)); + } +} + +} //namespace armnn -- cgit v1.2.1