From e5339e7013cf24e5a34509fb0a60377e5f8a244e Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Wed, 28 Jul 2021 17:33:28 +0100 Subject: MLCE-530 Add support for UnidirectionalSequenceLstm to RefWorkload * Add implementation of IsUnidirectionalSequenceLstmSupported to RefLayerSupport * Add RefUnidirectionalSequenceLstmWorkload * Refactor Lstm to be able to use for Lstm and SequenceLstm * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Ibc066d213213a11b955dfefbe518de643298ba0c --- src/backends/reference/workloads/CMakeLists.txt | 4 + src/backends/reference/workloads/Lstm.cpp | 259 +++++++++++++++++ src/backends/reference/workloads/Lstm.hpp | 61 ++++ .../reference/workloads/RefLstmWorkload.cpp | 235 ++++------------ .../RefUnidirectionalSequenceLstmWorkload.cpp | 307 +++++++++++++++++++++ .../RefUnidirectionalSequenceLstmWorkload.hpp | 56 ++++ src/backends/reference/workloads/RefWorkloads.hpp | 1 + 7 files changed, 735 insertions(+), 188 deletions(-) create mode 100644 src/backends/reference/workloads/Lstm.cpp create mode 100644 src/backends/reference/workloads/Lstm.hpp create mode 100644 src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp create mode 100644 src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 7a769e5246..b9f477cb6d 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -42,6 +42,8 @@ list(APPEND armnnRefBackendWorkloads_sources Log.hpp LogSoftmax.cpp LogSoftmax.hpp + Lstm.cpp + Lstm.hpp LstmUtils.hpp LstmUtils.cpp Maximum.hpp @@ -162,6 +164,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefTransposeConvolution2dWorkload.hpp RefTransposeWorkload.cpp RefTransposeWorkload.hpp + RefUnidirectionalSequenceLstmWorkload.cpp + RefUnidirectionalSequenceLstmWorkload.hpp RefWorkloads.hpp RefWorkloadUtils.hpp Resize.cpp diff --git a/src/backends/reference/workloads/Lstm.cpp b/src/backends/reference/workloads/Lstm.cpp new file mode 100644 index 0000000000..c1fb2bf4aa --- /dev/null +++ b/src/backends/reference/workloads/Lstm.cpp @@ -0,0 +1,259 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "Activation.hpp" +#include "Lstm.hpp" +#include "LstmUtils.hpp" + +namespace armnn +{ + +void LstmImpl(const LstmDescriptor& descriptor, + const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const TensorShape& inputToOutputWeightsShape, + const TensorShape& recurrentToOutputWeightsShape, + std::unique_ptr>& inputData, + std::unique_ptr>& outputStateIn, + std::unique_ptr>& cellStateIn, + std::unique_ptr>& outputStateOut, + std::unique_ptr>& cellStateOut, + std::unique_ptr>& output, + std::unique_ptr>& cellStateOutDecoder, + std::unique_ptr>& outputDecoder, + std::unique_ptr>& inputToInputWeightsTensor, + std::unique_ptr>& inputToForgetWeightsTensor, + std::unique_ptr>& inputToCellWeightsTensor, + std::unique_ptr>& inputToOutputWeightsTensor, + std::unique_ptr>& recurrentToInputWeightsTensor, + std::unique_ptr>& recurrentToForgetWeightsTensor, + std::unique_ptr>& recurrentToCellWeightsTensor, + std::unique_ptr>& recurrentToOutputWeightsTensor, + std::unique_ptr>& cellToInputWeightsTensor, + std::unique_ptr>& cellToForgetWeightsTensor, + std::unique_ptr>& cellToOutputWeightsTensor, + std::unique_ptr>& inputGateBiasTensor, + std::unique_ptr>& forgetGateBiasTensor, + std::unique_ptr>& cellBiasTensor, + std::unique_ptr>& outputGateBiasTensor, + std::unique_ptr>& projectionWeightsTensor, + std::unique_ptr>& projectionBiasTensor, + std::unique_ptr>& inputLayerNormWeights, + std::unique_ptr>& forgetLayerNormWeights, + std::unique_ptr>& cellLayerNormWeights, + std::unique_ptr>& outputLayerNormWeights, + std::unique_ptr>& inputGateScratch, + std::unique_ptr>& cellScratch, + std::unique_ptr>& forgetGateScratch, + std::unique_ptr>& outputGateScratch, + std::unique_ptr>& inputGateScratchDecoder, + std::unique_ptr>& cellScratchDecoder, + std::unique_ptr>& forgetGateScratchDecoder, + std::unique_ptr>& outputGateScratchDecoder, + float layerNormEpsilon) +{ + // This is a porting of the LSTM::Eval() method in the Android code base + // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp + + const TensorShape& inputShape = inputInfo.GetShape(); + const DataType& outputType = outputInfo.GetDataType(); + + const uint32_t nBatch = inputShape[0]; + const uint32_t nInput = inputShape[1]; + + const uint32_t nCell = inputToOutputWeightsShape[0]; + const uint32_t nOutput = recurrentToOutputWeightsShape[1]; + + const bool useCifg = descriptor.m_CifgEnabled; + const bool usePeephole = descriptor.m_PeepholeEnabled; + const bool useLayerNorm = descriptor.m_LayerNormEnabled; + + if (!useLayerNorm) + { + // Initialize scratch buffers with bias. + if (!useCifg) + { + VectorBatchVectorAssign(*inputGateBiasTensor, + nCell, nBatch, *inputGateScratch); + } + VectorBatchVectorAssign(*forgetGateBiasTensor, + nCell, nBatch, *forgetGateScratch); + VectorBatchVectorAssign(*cellBiasTensor, + nCell, nBatch, *cellScratch); + VectorBatchVectorAssign(*outputGateBiasTensor, + nCell, nBatch, *outputGateScratch); + } + else + { + // Initialize scratch buffers with zeroes. + if (!useCifg) + { + ZeroVector(*inputGateScratch, nCell * nBatch); + } + ZeroVector(*forgetGateScratch, nCell * nBatch); + ZeroVector(*cellScratch , nCell * nBatch); + ZeroVector(*outputGateScratch, nCell * nBatch); + } + + // For each batch and cell: compute input_weight * input. + if (!useCifg) + { + MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, + nCell, nInput, *inputData, nBatch, *inputGateScratch); + } + MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, + nCell, nInput, *inputData, nBatch, *forgetGateScratch); + MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, + nCell, nInput, *inputData, nBatch, *cellScratch); + MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, + nCell, nInput, *inputData, nBatch, *outputGateScratch); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!useCifg) + { + MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); + } + MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); + MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *cellScratch); + MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); + + // For each batch and cell: update input gate. + if (!useCifg) + { + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, + nCell, *cellStateIn, nBatch, *inputGateScratch); + } + if (useLayerNorm) + { + MeanStddevNormalization(*inputGateScratchDecoder, + *inputGateScratch, nCell, nBatch, layerNormEpsilon); + VectorBatchVectorCwiseProduct(*inputLayerNormWeights, + nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); + VectorBatchVectorAdd(*inputGateBiasTensor, + nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); + } + Activation(*inputGateScratchDecoder, *inputGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + } + + // For each batch and cell: update forget gate. + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, + *cellStateIn, nBatch, *forgetGateScratch); + } + if (useLayerNorm) + { + MeanStddevNormalization(*forgetGateScratchDecoder, + *forgetGateScratch, nCell, nBatch, layerNormEpsilon); + VectorBatchVectorCwiseProduct(*forgetLayerNormWeights, + nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); + VectorBatchVectorAdd(*forgetGateBiasTensor, + nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); + } + Activation(*forgetGateScratchDecoder, *forgetGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + + // For each batch and cell: update the cell. + if (useLayerNorm) + { + MeanStddevNormalization(*cellScratchDecoder, + *cellScratch, nCell, nBatch, layerNormEpsilon); + VectorBatchVectorCwiseProduct(*cellLayerNormWeights, + nCell, *cellScratchDecoder, nBatch, *cellScratch); + VectorBatchVectorAdd(*cellBiasTensor, + nCell, *cellScratchDecoder, nBatch, *cellScratch); + } + + VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); + + ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; + float a = 0; + float b = 0; + SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b); + + if (descriptor.m_ActivationFunc > 0) + { + Activation(*cellScratchDecoder, *cellScratch, + TensorInfo({nCell, nBatch}, outputType), + armnnActivationFunc, a, b); + } + if (useCifg) + { + Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); + VectorVectorCwiseProductAccumulate( + *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); + } + else + { + VectorVectorCwiseProductAccumulate( + *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); + } + if (descriptor.m_ClippingThresCell > 0.0) + { + ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut); + } + + // For each batch and cell: update the output gate. + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, + nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); + } + if (useLayerNorm) + { + MeanStddevNormalization(*outputGateScratchDecoder, + *outputGateScratch, nCell, nBatch, layerNormEpsilon); + VectorBatchVectorCwiseProduct(*outputLayerNormWeights, + nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); + VectorBatchVectorAdd(*outputGateBiasTensor, + nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); + } + Activation(*outputGateScratchDecoder, *outputGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + + if (descriptor.m_ActivationFunc > 0) + { + Activation(*cellStateOutDecoder, *cellScratch, + TensorInfo({nCell, nBatch}, outputType), + armnnActivationFunc, a, b); + } + + VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); + + // For each batch: update the projection and output_state. + if (descriptor.m_ProjectionEnabled) + { + if (projectionBiasTensor) + { + VectorBatchVectorAssign(*projectionBiasTensor, + nOutput, nBatch, *output); + } + MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, + nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); + + if (descriptor.m_ClippingThresProj > 0.0) + { + ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output); + } + } + else + { + CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); + } + + CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/Lstm.hpp b/src/backends/reference/workloads/Lstm.hpp new file mode 100644 index 0000000000..7d0a1d436e --- /dev/null +++ b/src/backends/reference/workloads/Lstm.hpp @@ -0,0 +1,61 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +#include "Encoders.hpp" +#include "Decoders.hpp" + +namespace armnn +{ + +void LstmImpl(const LstmDescriptor& descriptor, + const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const TensorShape& inputToOutputWeightsShape, + const TensorShape& recurrentToOutputWeightsShape, + std::unique_ptr>& inputData, + std::unique_ptr>& outputStateIn, + std::unique_ptr>& cellStateIn, + std::unique_ptr>& outputStateOut, + std::unique_ptr>& cellStateOut, + std::unique_ptr>& output, + std::unique_ptr>& cellStateOutDecoder, + std::unique_ptr>& outputDecoder, + std::unique_ptr>& inputToInputWeightsTensor, + std::unique_ptr>& inputToForgetWeightsTensor, + std::unique_ptr>& inputToCellWeightsTensor, + std::unique_ptr>& inputToOutputWeightsTensor, + std::unique_ptr>& recurrentToInputWeightsTensor, + std::unique_ptr>& recurrentToForgetWeightsTensor, + std::unique_ptr>& recurrentToCellWeightsTensor, + std::unique_ptr>& recurrentToOutputWeightsTensor, + std::unique_ptr>& cellToInputWeightsTensor, + std::unique_ptr>& cellToForgetWeightsTensor, + std::unique_ptr>& cellToOutputWeightsTensor, + std::unique_ptr>& inputGateBiasTensor, + std::unique_ptr>& forgetGateBiasTensor, + std::unique_ptr>& cellBiasTensor, + std::unique_ptr>& outputGateBiasTensor, + std::unique_ptr>& projectionWeightsTensor, + std::unique_ptr>& projectionBiasTensor, + std::unique_ptr>& inputLayerNormWeights, + std::unique_ptr>& forgetLayerNormWeights, + std::unique_ptr>& cellLayerNormWeights, + std::unique_ptr>& outputLayerNormWeights, + std::unique_ptr>& inputGateScratch, + std::unique_ptr>& cellScratch, + std::unique_ptr>& forgetGateScratch, + std::unique_ptr>& outputGateScratch, + std::unique_ptr>& inputGateScratchDecoder, + std::unique_ptr>& cellScratchDecoder, + std::unique_ptr>& forgetGateScratchDecoder, + std::unique_ptr>& outputGateScratchDecoder, + float layerNormEpsilon); + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp index 3ddfd334b8..1ff6f50ed5 100644 --- a/src/backends/reference/workloads/RefLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefLstmWorkload.cpp @@ -7,6 +7,7 @@ #include "Activation.hpp" #include "Encoders.hpp" #include "Decoders.hpp" +#include "Lstm.hpp" #include "LstmUtils.hpp" #include "RefWorkloadUtils.hpp" @@ -57,7 +58,6 @@ void RefLstmWorkload::Execute(std::vector inputs, std::vector> outputStateOut = MakeEncoder(outputInfo, outputs[1]->Map()); std::unique_ptr> cellStateOut = MakeEncoder(outputInfo, outputs[2]->Map()); @@ -71,10 +71,7 @@ void RefLstmWorkload::Execute(std::vector inputs, std::vector> cellStateIn = MakeDecoder(inputInfo, inputs[2]->Map()); const uint32_t nBatch = inputShape[0]; - const uint32_t nInput = inputShape[1]; - const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; - const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; @@ -154,6 +151,9 @@ void RefLstmWorkload::Execute(std::vector inputs, std::vector> cellLayerNormWeights; std::unique_ptr> outputLayerNormWeights; + const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); + const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); + if (useLayerNorm) { if (!useCifg) @@ -204,190 +204,49 @@ void RefLstmWorkload::Execute(std::vector inputs, std::vector 0) - { - Activation(*cellScratchDecoder, *cellScratch, - TensorInfo({nCell, nBatch}, outputType), - armnnActivationFunc, a, b); - } - if (useCifg) - { - Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); - VectorVectorCwiseProductAccumulate( - *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); - } - else - { - VectorVectorCwiseProductAccumulate( - *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); - } - if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) - { - ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut); - } - - // For each batch and cell: update the output gate. - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, - nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); - } - if (useLayerNorm) - { - MeanStddevNormalization(*outputGateScratchDecoder, - *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*outputLayerNormWeights, - nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); - VectorBatchVectorAdd(*outputGateBiasTensor, - nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); - } - Activation(*outputGateScratchDecoder, *outputGateScratch, - TensorInfo({nCell, nBatch}, outputType), - ActivationFunction::Sigmoid, 0, 0); - - if (m_Data.m_Parameters.m_ActivationFunc > 0) - { - Activation(*cellStateOutDecoder, *cellScratch, - TensorInfo({nCell, nBatch}, outputType), - armnnActivationFunc, a, b); - } - - VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); - - // For each batch: update the projection and output_state. - if (m_Data.m_Parameters.m_ProjectionEnabled) - { - if (m_ProjectionBiasTensor) - { - VectorBatchVectorAssign(*projectionBiasTensor, - nOutput, nBatch, *output); - } - MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, - nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); - - if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) - { - ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output); - } - } - else - { - CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); - } - - CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); + LstmImpl(m_Data.m_Parameters, + inputInfo, + outputInfo, + inputToOutputWeightsShape, + recurrentToOutputWeightsShape, + inputData, + outputStateIn, + cellStateIn, + outputStateOut, + cellStateOut, + output, + cellStateOutDecoder, + outputDecoder, + inputToInputWeightsTensor, + inputToForgetWeightsTensor, + inputToCellWeightsTensor, + inputToOutputWeightsTensor, + recurrentToInputWeightsTensor, + recurrentToForgetWeightsTensor, + recurrentToCellWeightsTensor, + recurrentToOutputWeightsTensor, + cellToInputWeightsTensor, + cellToForgetWeightsTensor, + cellToOutputWeightsTensor, + inputGateBiasTensor, + forgetGateBiasTensor, + cellBiasTensor, + outputGateBiasTensor, + projectionWeightsTensor, + projectionBiasTensor, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights, + inputGateScratch, + cellScratch, + forgetGateScratch, + outputGateScratch, + inputGateScratchDecoder, + cellScratchDecoder, + forgetGateScratchDecoder, + outputGateScratchDecoder, + m_LayerNormEpsilon); } } //namespace armnn diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp new file mode 100644 index 0000000000..311fa18f91 --- /dev/null +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -0,0 +1,307 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefUnidirectionalSequenceLstmWorkload.hpp" +#include "Activation.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "Lstm.hpp" +#include "LstmUtils.hpp" +#include "RefWorkloadUtils.hpp" + +#include + +namespace armnn +{ + +RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload( + const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) + , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) + , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) + , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) + , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) + , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) + , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) + , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) + , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) + , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) + , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) + , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) + , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) + , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) + , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) + , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) + , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) + , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) + , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) + , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) + , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) + , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) +{} + +void RefUnidirectionalSequenceLstmWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) +{ + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector inputs, + std::vector outputs) const +{ + TensorInfo inputInfo = GetTensorInfo(inputs[0]); + const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); + const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); + TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorShape& inputShape = inputInfo.GetShape(); + TensorShape& outputShape= outputInfo.GetShape(); + auto inputTensor = reinterpret_cast(inputs[0]->Map()); + + if (!m_Data.m_Parameters.m_TimeMajor) + { + // Permute to time major + const PermutationVector& mappings = {1U, 0U, 2U}; + std::vector inputValue(inputTensor, inputTensor + inputInfo.GetNumElements()); + inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings); + inputInfo.SetShape(inputShape); + armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float)); + + outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); + outputInfo.SetShape(outputShape); + } + unsigned int maxTime = inputShape[0]; + unsigned int batchSize = inputShape[1]; + unsigned int outputSize = outputShape[2]; + unsigned int inputSize = inputShape[2]; + + TensorInfo scratchInfo = outputInfo; + scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]}); + + std::vector inputGateScratchBuffer; + std::vector cellScratchBuffer(scratchInfo.GetNumElements(), 0.); + std::vector forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.); + std::vector outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.); + + std::vector outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.); + std::vector cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.); + + void* outputStateOutData = outputStateOutBuffer.data(); + void* cellStateOutData = cellStateOutBuffer.data(); + + std::unique_ptr> inputGateScratch; + std::unique_ptr> cellScratch = MakeEncoder(scratchInfo, cellScratchBuffer.data()); + std::unique_ptr> forgetGateScratch = MakeEncoder(scratchInfo, forgetGateScratchBuffer.data()); + std::unique_ptr> outputGateScratch = MakeEncoder(scratchInfo, outputGateScratchBuffer.data()); + + std::unique_ptr> inputGateScratchDecoder; + std::unique_ptr> cellScratchDecoder = MakeDecoder(scratchInfo, cellScratchBuffer.data()); + std::unique_ptr> forgetGateScratchDecoder = MakeDecoder(scratchInfo, + forgetGateScratchBuffer.data()); + std::unique_ptr> outputGateScratchDecoder = MakeDecoder(scratchInfo, + outputGateScratchBuffer.data()); + + const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; + const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; + const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; + + if (!useCifg) + { + inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.); + inputGateScratch = MakeEncoder(scratchInfo, inputGateScratchBuffer.data()); + inputGateScratchDecoder = MakeDecoder(scratchInfo, inputGateScratchBuffer.data()); + } + + std::unique_ptr> outputStateOut = MakeEncoder(outputStateInfo, outputStateOutData); + std::unique_ptr> cellStateOut = MakeEncoder(cellStateInfo, cellStateOutData); + std::unique_ptr> cellStateOutDecoder = MakeDecoder(cellStateInfo, cellStateOutData); + + TensorInfo lstmInputInfo = inputInfo; + TensorShape batchInputShape = TensorShape({batchSize, inputSize}); + lstmInputInfo.SetShape(batchInputShape); + + TensorInfo lstmOutputInfo = outputInfo; + lstmOutputInfo.SetShape({batchSize, outputSize}); + + const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); + const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); + unsigned int nOutput = recurrentToOutputWeightsShape[1]; + auto outputStateInData = inputs[1]->Map(); + std::unique_ptr> outputStateIn = MakeDecoder(outputStateInfo, outputStateInData); + + auto cellStateInData = inputs[2]->Map(); + std::unique_ptr> cellStateIn = MakeDecoder(cellStateInfo, cellStateInData); + + auto currentInputData = reinterpret_cast(inputs[0]->Map()); + std::unique_ptr> inputData = MakeDecoder(lstmInputInfo, currentInputData); + auto currentOutputData = reinterpret_cast(outputs[0]->Map()); + std::unique_ptr> output = MakeEncoder(lstmOutputInfo, currentOutputData); + std::unique_ptr> outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData); + + std::unique_ptr> inputToInputWeightsTensor; + std::unique_ptr> inputToForgetWeightsTensor = MakeDecoder( + m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor()); + std::unique_ptr> inputToCellWeightsTensor = MakeDecoder( + m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor()); + std::unique_ptr> inputToOutputWeightsTensor = MakeDecoder( + m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor()); + + std::unique_ptr> recurrentToInputWeightsTensor; + std::unique_ptr> recurrentToForgetWeightsTensor = MakeDecoder( + m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor()); + std::unique_ptr> recurrentToCellWeightsTensor = MakeDecoder( + m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor()); + std::unique_ptr> recurrentToOutputWeightsTensor = MakeDecoder( + m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor()); + + std::unique_ptr> inputGateBiasTensor; + std::unique_ptr> forgetGateBiasTensor = MakeDecoder( + m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor()); + std::unique_ptr> cellBiasTensor = MakeDecoder( + m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor()); + std::unique_ptr> outputGateBiasTensor = MakeDecoder( + m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor()); + + std::unique_ptr> cellToInputWeightsTensor; + std::unique_ptr> cellToForgetWeightsTensor; + std::unique_ptr> cellToOutputWeightsTensor; + + std::unique_ptr> projectionWeightsTensor; + std::unique_ptr> projectionBiasTensor; + + std::unique_ptr> inputLayerNormWeights; + std::unique_ptr> forgetLayerNormWeights; + std::unique_ptr> cellLayerNormWeights; + std::unique_ptr> outputLayerNormWeights; + + if (useLayerNorm) + { + if (!useCifg) + { + inputLayerNormWeights = MakeDecoder( + m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor()); + } + forgetLayerNormWeights = MakeDecoder( + m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor()); + cellLayerNormWeights = MakeDecoder( + m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor()); + outputLayerNormWeights = MakeDecoder( + m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor()); + } + + if (!useCifg) + { + inputToInputWeightsTensor = MakeDecoder( + m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor()); + inputGateBiasTensor = MakeDecoder( + m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor()); + recurrentToInputWeightsTensor = MakeDecoder( + m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor()); + } + + if (usePeephole) + { + cellToForgetWeightsTensor = MakeDecoder( + m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor()); + cellToOutputWeightsTensor = MakeDecoder( + m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor()); + } + + if (!useCifg && usePeephole) + { + cellToInputWeightsTensor = MakeDecoder( + m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor()); + } + + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + projectionWeightsTensor = MakeDecoder( + m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor()); + if (m_ProjectionBiasTensor) + { + projectionBiasTensor = MakeDecoder( + m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor()); + } + } + + unsigned int batchInputSize = batchSize * inputSize; + unsigned int batchOutputSize = batchSize * nOutput; + + for (unsigned int t = 0; t < maxTime; ++t) + { + LstmImpl(m_Data.m_Parameters, + lstmInputInfo, + lstmOutputInfo, + inputToOutputWeightsShape, + recurrentToOutputWeightsShape, + inputData, + outputStateIn, + cellStateIn, + outputStateOut, + cellStateOut, + output, + cellStateOutDecoder, + outputDecoder, + inputToInputWeightsTensor, + inputToForgetWeightsTensor, + inputToCellWeightsTensor, + inputToOutputWeightsTensor, + recurrentToInputWeightsTensor, + recurrentToForgetWeightsTensor, + recurrentToCellWeightsTensor, + recurrentToOutputWeightsTensor, + cellToInputWeightsTensor, + cellToForgetWeightsTensor, + cellToOutputWeightsTensor, + inputGateBiasTensor, + forgetGateBiasTensor, + cellBiasTensor, + outputGateBiasTensor, + projectionWeightsTensor, + projectionBiasTensor, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights, + inputGateScratch, + cellScratch, + forgetGateScratch, + outputGateScratch, + inputGateScratchDecoder, + cellScratchDecoder, + forgetGateScratchDecoder, + outputGateScratchDecoder, + m_LayerNormEpsilon); + + currentInputData += batchInputSize; + inputData = MakeDecoder(lstmInputInfo, currentInputData); + currentOutputData += batchOutputSize; + output = MakeEncoder(lstmOutputInfo, currentOutputData); + outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData); + + // Assign output state out to the next output state in + outputStateIn = MakeDecoder(outputStateInfo, outputStateOutData); + + // Assign cell state out to the next cell state in + cellStateIn = MakeDecoder(cellStateInfo, cellStateOutData); + } + + if (!m_Data.m_Parameters.m_TimeMajor) + { + // Permute Output back to batch major + const PermutationVector& mappings = {1U, 0U, 2U}; + auto outputData = reinterpret_cast(outputs[0]->Map()); + std::vector outputValue(outputData, outputData + outputInfo.GetNumElements()); + outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); + outputInfo.SetShape(outputShape); + armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float)); + } +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp new file mode 100644 index 0000000000..8ba7bdc0c6 --- /dev/null +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp @@ -0,0 +1,56 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +#include +#include + +#include "Encoders.hpp" +#include "Decoders.hpp" + +namespace armnn +{ + +class RefUnidirectionalSequenceLstmWorkload : public BaseWorkload +{ +public: + explicit RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; + + +private: + void Execute(std::vector inputs, std::vector outputs) const; + std::unique_ptr m_InputToInputWeightsTensor; + std::unique_ptr m_InputToForgetWeightsTensor; + std::unique_ptr m_InputToCellWeightsTensor; + std::unique_ptr m_InputToOutputWeightsTensor; + std::unique_ptr m_RecurrentToInputWeightsTensor; + std::unique_ptr m_RecurrentToForgetWeightsTensor; + std::unique_ptr m_RecurrentToCellWeightsTensor; + std::unique_ptr m_RecurrentToOutputWeightsTensor; + std::unique_ptr m_CellToInputWeightsTensor; + std::unique_ptr m_CellToForgetWeightsTensor; + std::unique_ptr m_CellToOutputWeightsTensor; + std::unique_ptr m_InputGateBiasTensor; + std::unique_ptr m_ForgetGateBiasTensor; + std::unique_ptr m_CellBiasTensor; + std::unique_ptr m_OutputGateBiasTensor; + std::unique_ptr m_ProjectionWeightsTensor; + std::unique_ptr m_ProjectionBiasTensor; + std::unique_ptr m_InputLayerNormWeights; + std::unique_ptr m_ForgetLayerNormWeights; + std::unique_ptr m_CellLayerNormWeights; + std::unique_ptr m_OutputLayerNormWeights; + + float m_LayerNormEpsilon = static_cast(1e-8); +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index afe63d13c0..d3ae58ea15 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -69,6 +69,7 @@ #include "RefSpaceToDepthWorkload.hpp" #include "RefTransposeConvolution2dWorkload.hpp" #include "RefTransposeWorkload.hpp" +#include "RefUnidirectionalSequenceLstmWorkload.hpp" #include "RefWorkloadUtils.hpp" #include "Resize.hpp" #include "Softmax.hpp" -- cgit v1.2.1