diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 147 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 11 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 7 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.hpp | 4 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 2 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 12 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/Lstm.cpp | 259 | ||||
-rw-r--r-- | src/backends/reference/workloads/Lstm.hpp | 61 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefLstmWorkload.cpp | 235 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp | 307 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp | 56 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 1 |
13 files changed, 918 insertions, 188 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 1b05c4e0f4..2603371927 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1242,6 +1242,7 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, "Reference Lstm: input and outputStateOut types are mismatched"); supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported, "Reference Lstm: input and cellStateOut types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, "Reference Lstm: input and output types are mismatched"); // check layer parameters @@ -2288,4 +2289,150 @@ bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( + const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& output, + const Optional<TensorInfo>& hiddenStateOutput, + const Optional<TensorInfo>& cellStateOutput, + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const +{ + IgnoreUnused(descriptor); + IgnoreUnused(paramsInfo); + IgnoreUnused(outputStateIn); + IgnoreUnused(cellStateIn); + bool supported = true; + + if (hiddenStateOutput.has_value() || cellStateOutput.has_value()) + { + reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output " + "and cell state output are not supported at the moment."; + } + + std::array<DataType, 1> supportedTypes = + { + DataType::Float32 + }; + + std::array<DataType, 1> supportedWeightTypes = + { + DataType::Float32 + }; + + // check inputs and outputs + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and output types are mismatched"); + // check layer parameters + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToOutputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types " + "are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and OutputGateBias types " + "are mismatched"); + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToInputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and InputGateBias types " + "are mismatched"); + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToInputWeights " + "is not a supported type."); + } + } + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToOutputWeights " + "is not a supported type."); + } + if (descriptor.m_ProjectionEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ProjectionWeights " + "is not a supported type."); + if (paramsInfo.m_ProjectionBias != nullptr) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and ProjectionBias types " + "are mismatched"); + } + } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputLayerNormWeights " + "is not a supported type."); + } + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellLayerNormWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights " + "is not a supported type."); + } + + return supported; +} + } // namespace armnn diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index c060f79b5a..a1b4dc7f47 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -370,6 +370,17 @@ public: const TensorInfo& output, const TransposeDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + + bool IsUnidirectionalSequenceLstmSupported( + const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& output, + const Optional<TensorInfo>& hiddenStateOutput, + const Optional<TensorInfo>& cellStateOutput, + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; }; } // namespace armnn diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 606f531630..16cf17cc79 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -712,4 +712,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTransposeConvolution2d( return std::make_unique<RefTransposeConvolution2dWorkload>(descriptor, info); } +std::unique_ptr<IWorkload> RefWorkloadFactory::CreateUnidirectionalSequenceLstm( + const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique<RefUnidirectionalSequenceLstmWorkload>(descriptor, info);; +} + } // namespace armnn diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index 2beffa77f3..113aca70ef 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -276,6 +276,10 @@ public: std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateUnidirectionalSequenceLstm( + const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + private: template <typename F32Workload, typename U8Workload, typename QueueDescriptorType> std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index bf18284143..17ddbe0df1 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -37,6 +37,7 @@ BACKEND_SOURCES := \ workloads/Gather.cpp \ workloads/InstanceNorm.cpp \ workloads/LogSoftmax.cpp \ + workloads/Lstm.cpp \ workloads/LstmUtils.cpp \ workloads/Concatenate.cpp \ workloads/Pad.cpp \ @@ -95,6 +96,7 @@ BACKEND_SOURCES := \ workloads/RefSplitterWorkload.cpp \ workloads/RefTransposeConvolution2dWorkload.cpp \ workloads/RefTransposeWorkload.cpp \ + workloads/RefUnidirectionalSequenceLstmWorkload.cpp \ workloads/Resize.cpp \ workloads/Slice.cpp \ workloads/SpaceToBatchNd.cpp \ diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 45e3717268..0cf36f2c6e 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -2330,4 +2330,16 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceMax2Float32, ReduceMaxSimpleTest2<DataType:: ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceMinFloat32, ReduceMinSimpleTest<DataType::Float32>) ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceMinNegativeAxisFloat32, ReduceMinNegativeAxisTest<DataType::Float32>) +// Unidirectional Sequence Lstm +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32, + UnidirectionalSequenceLstmLayerFloat32Test) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32TimeMajor, + UnidirectionalSequenceLstmLayerFloat32TimeMajorTest) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjection, + UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionTest) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionWithLayerNorm, + UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTest) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjection, + UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjectionTest) + }
\ No newline at end of file diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 7a769e5246..b9f477cb6d 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -42,6 +42,8 @@ list(APPEND armnnRefBackendWorkloads_sources Log.hpp LogSoftmax.cpp LogSoftmax.hpp + Lstm.cpp + Lstm.hpp LstmUtils.hpp LstmUtils.cpp Maximum.hpp @@ -162,6 +164,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefTransposeConvolution2dWorkload.hpp RefTransposeWorkload.cpp RefTransposeWorkload.hpp + RefUnidirectionalSequenceLstmWorkload.cpp + RefUnidirectionalSequenceLstmWorkload.hpp RefWorkloads.hpp RefWorkloadUtils.hpp Resize.cpp diff --git a/src/backends/reference/workloads/Lstm.cpp b/src/backends/reference/workloads/Lstm.cpp new file mode 100644 index 0000000000..c1fb2bf4aa --- /dev/null +++ b/src/backends/reference/workloads/Lstm.cpp @@ -0,0 +1,259 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "Activation.hpp" +#include "Lstm.hpp" +#include "LstmUtils.hpp" + +namespace armnn +{ + +void LstmImpl(const LstmDescriptor& descriptor, + const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const TensorShape& inputToOutputWeightsShape, + const TensorShape& recurrentToOutputWeightsShape, + std::unique_ptr<Decoder<float>>& inputData, + std::unique_ptr<Decoder<float>>& outputStateIn, + std::unique_ptr<Decoder<float>>& cellStateIn, + std::unique_ptr<Encoder<float>>& outputStateOut, + std::unique_ptr<Encoder<float>>& cellStateOut, + std::unique_ptr<Encoder<float>>& output, + std::unique_ptr<Decoder<float>>& cellStateOutDecoder, + std::unique_ptr<Decoder<float>>& outputDecoder, + std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor, + std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor, + std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor, + std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor, + std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor, + std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor, + std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor, + std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor, + std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor, + std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor, + std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor, + std::unique_ptr<Decoder<float>>& inputGateBiasTensor, + std::unique_ptr<Decoder<float>>& forgetGateBiasTensor, + std::unique_ptr<Decoder<float>>& cellBiasTensor, + std::unique_ptr<Decoder<float>>& outputGateBiasTensor, + std::unique_ptr<Decoder<float>>& projectionWeightsTensor, + std::unique_ptr<Decoder<float>>& projectionBiasTensor, + std::unique_ptr<Decoder<float>>& inputLayerNormWeights, + std::unique_ptr<Decoder<float>>& forgetLayerNormWeights, + std::unique_ptr<Decoder<float>>& cellLayerNormWeights, + std::unique_ptr<Decoder<float>>& outputLayerNormWeights, + std::unique_ptr<Encoder<float>>& inputGateScratch, + std::unique_ptr<Encoder<float>>& cellScratch, + std::unique_ptr<Encoder<float>>& forgetGateScratch, + std::unique_ptr<Encoder<float>>& outputGateScratch, + std::unique_ptr<Decoder<float>>& inputGateScratchDecoder, + std::unique_ptr<Decoder<float>>& cellScratchDecoder, + std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder, + std::unique_ptr<Decoder<float>>& outputGateScratchDecoder, + float layerNormEpsilon) +{ + // This is a porting of the LSTM::Eval() method in the Android code base + // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp + + const TensorShape& inputShape = inputInfo.GetShape(); + const DataType& outputType = outputInfo.GetDataType(); + + const uint32_t nBatch = inputShape[0]; + const uint32_t nInput = inputShape[1]; + + const uint32_t nCell = inputToOutputWeightsShape[0]; + const uint32_t nOutput = recurrentToOutputWeightsShape[1]; + + const bool useCifg = descriptor.m_CifgEnabled; + const bool usePeephole = descriptor.m_PeepholeEnabled; + const bool useLayerNorm = descriptor.m_LayerNormEnabled; + + if (!useLayerNorm) + { + // Initialize scratch buffers with bias. + if (!useCifg) + { + VectorBatchVectorAssign(*inputGateBiasTensor, + nCell, nBatch, *inputGateScratch); + } + VectorBatchVectorAssign(*forgetGateBiasTensor, + nCell, nBatch, *forgetGateScratch); + VectorBatchVectorAssign(*cellBiasTensor, + nCell, nBatch, *cellScratch); + VectorBatchVectorAssign(*outputGateBiasTensor, + nCell, nBatch, *outputGateScratch); + } + else + { + // Initialize scratch buffers with zeroes. + if (!useCifg) + { + ZeroVector(*inputGateScratch, nCell * nBatch); + } + ZeroVector(*forgetGateScratch, nCell * nBatch); + ZeroVector(*cellScratch , nCell * nBatch); + ZeroVector(*outputGateScratch, nCell * nBatch); + } + + // For each batch and cell: compute input_weight * input. + if (!useCifg) + { + MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, + nCell, nInput, *inputData, nBatch, *inputGateScratch); + } + MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, + nCell, nInput, *inputData, nBatch, *forgetGateScratch); + MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, + nCell, nInput, *inputData, nBatch, *cellScratch); + MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, + nCell, nInput, *inputData, nBatch, *outputGateScratch); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!useCifg) + { + MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); + } + MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); + MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *cellScratch); + MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); + + // For each batch and cell: update input gate. + if (!useCifg) + { + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, + nCell, *cellStateIn, nBatch, *inputGateScratch); + } + if (useLayerNorm) + { + MeanStddevNormalization(*inputGateScratchDecoder, + *inputGateScratch, nCell, nBatch, layerNormEpsilon); + VectorBatchVectorCwiseProduct(*inputLayerNormWeights, + nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); + VectorBatchVectorAdd(*inputGateBiasTensor, + nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); + } + Activation(*inputGateScratchDecoder, *inputGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + } + + // For each batch and cell: update forget gate. + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, + *cellStateIn, nBatch, *forgetGateScratch); + } + if (useLayerNorm) + { + MeanStddevNormalization(*forgetGateScratchDecoder, + *forgetGateScratch, nCell, nBatch, layerNormEpsilon); + VectorBatchVectorCwiseProduct(*forgetLayerNormWeights, + nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); + VectorBatchVectorAdd(*forgetGateBiasTensor, + nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); + } + Activation(*forgetGateScratchDecoder, *forgetGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + + // For each batch and cell: update the cell. + if (useLayerNorm) + { + MeanStddevNormalization(*cellScratchDecoder, + *cellScratch, nCell, nBatch, layerNormEpsilon); + VectorBatchVectorCwiseProduct(*cellLayerNormWeights, + nCell, *cellScratchDecoder, nBatch, *cellScratch); + VectorBatchVectorAdd(*cellBiasTensor, + nCell, *cellScratchDecoder, nBatch, *cellScratch); + } + + VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); + + ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; + float a = 0; + float b = 0; + SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b); + + if (descriptor.m_ActivationFunc > 0) + { + Activation(*cellScratchDecoder, *cellScratch, + TensorInfo({nCell, nBatch}, outputType), + armnnActivationFunc, a, b); + } + if (useCifg) + { + Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); + VectorVectorCwiseProductAccumulate( + *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); + } + else + { + VectorVectorCwiseProductAccumulate( + *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); + } + if (descriptor.m_ClippingThresCell > 0.0) + { + ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut); + } + + // For each batch and cell: update the output gate. + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, + nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); + } + if (useLayerNorm) + { + MeanStddevNormalization(*outputGateScratchDecoder, + *outputGateScratch, nCell, nBatch, layerNormEpsilon); + VectorBatchVectorCwiseProduct(*outputLayerNormWeights, + nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); + VectorBatchVectorAdd(*outputGateBiasTensor, + nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); + } + Activation(*outputGateScratchDecoder, *outputGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + + if (descriptor.m_ActivationFunc > 0) + { + Activation(*cellStateOutDecoder, *cellScratch, + TensorInfo({nCell, nBatch}, outputType), + armnnActivationFunc, a, b); + } + + VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); + + // For each batch: update the projection and output_state. + if (descriptor.m_ProjectionEnabled) + { + if (projectionBiasTensor) + { + VectorBatchVectorAssign(*projectionBiasTensor, + nOutput, nBatch, *output); + } + MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, + nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); + + if (descriptor.m_ClippingThresProj > 0.0) + { + ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output); + } + } + else + { + CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); + } + + CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/Lstm.hpp b/src/backends/reference/workloads/Lstm.hpp new file mode 100644 index 0000000000..7d0a1d436e --- /dev/null +++ b/src/backends/reference/workloads/Lstm.hpp @@ -0,0 +1,61 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/TypesUtils.hpp> +#include <backendsCommon/WorkloadData.hpp> + +#include "Encoders.hpp" +#include "Decoders.hpp" + +namespace armnn +{ + +void LstmImpl(const LstmDescriptor& descriptor, + const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const TensorShape& inputToOutputWeightsShape, + const TensorShape& recurrentToOutputWeightsShape, + std::unique_ptr<Decoder<float>>& inputData, + std::unique_ptr<Decoder<float>>& outputStateIn, + std::unique_ptr<Decoder<float>>& cellStateIn, + std::unique_ptr<Encoder<float>>& outputStateOut, + std::unique_ptr<Encoder<float>>& cellStateOut, + std::unique_ptr<Encoder<float>>& output, + std::unique_ptr<Decoder<float>>& cellStateOutDecoder, + std::unique_ptr<Decoder<float>>& outputDecoder, + std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor, + std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor, + std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor, + std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor, + std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor, + std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor, + std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor, + std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor, + std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor, + std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor, + std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor, + std::unique_ptr<Decoder<float>>& inputGateBiasTensor, + std::unique_ptr<Decoder<float>>& forgetGateBiasTensor, + std::unique_ptr<Decoder<float>>& cellBiasTensor, + std::unique_ptr<Decoder<float>>& outputGateBiasTensor, + std::unique_ptr<Decoder<float>>& projectionWeightsTensor, + std::unique_ptr<Decoder<float>>& projectionBiasTensor, + std::unique_ptr<Decoder<float>>& inputLayerNormWeights, + std::unique_ptr<Decoder<float>>& forgetLayerNormWeights, + std::unique_ptr<Decoder<float>>& cellLayerNormWeights, + std::unique_ptr<Decoder<float>>& outputLayerNormWeights, + std::unique_ptr<Encoder<float>>& inputGateScratch, + std::unique_ptr<Encoder<float>>& cellScratch, + std::unique_ptr<Encoder<float>>& forgetGateScratch, + std::unique_ptr<Encoder<float>>& outputGateScratch, + std::unique_ptr<Decoder<float>>& inputGateScratchDecoder, + std::unique_ptr<Decoder<float>>& cellScratchDecoder, + std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder, + std::unique_ptr<Decoder<float>>& outputGateScratchDecoder, + float layerNormEpsilon); + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp index 3ddfd334b8..1ff6f50ed5 100644 --- a/src/backends/reference/workloads/RefLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefLstmWorkload.cpp @@ -7,6 +7,7 @@ #include "Activation.hpp" #include "Encoders.hpp" #include "Decoders.hpp" +#include "Lstm.hpp" #include "LstmUtils.hpp" #include "RefWorkloadUtils.hpp" @@ -57,7 +58,6 @@ void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); const TensorShape& inputShape = inputInfo.GetShape(); - const DataType& outputType = outputInfo.GetDataType(); std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map()); std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map()); @@ -71,10 +71,7 @@ void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map()); const uint32_t nBatch = inputShape[0]; - const uint32_t nInput = inputShape[1]; - const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; - const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; @@ -154,6 +151,9 @@ void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT std::unique_ptr<Decoder<float>> cellLayerNormWeights; std::unique_ptr<Decoder<float>> outputLayerNormWeights; + const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); + const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); + if (useLayerNorm) { if (!useCifg) @@ -204,190 +204,49 @@ void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT } } - if (!useLayerNorm) - { - // Initialize scratch buffers with bias. - if (!useCifg) - { - VectorBatchVectorAssign(*inputGateBiasTensor, - nCell, nBatch, *inputGateScratch); - } - VectorBatchVectorAssign(*forgetGateBiasTensor, - nCell, nBatch, *forgetGateScratch); - VectorBatchVectorAssign(*cellBiasTensor, - nCell, nBatch, *cellScratch); - VectorBatchVectorAssign(*outputGateBiasTensor, - nCell, nBatch, *outputGateScratch); - } - else - { - // Initialize scratch buffers with zeroes. - if (!useCifg) - { - ZeroVector(*inputGateScratch, nCell * nBatch); - } - ZeroVector(*forgetGateScratch, nCell * nBatch); - ZeroVector(*cellScratch , nCell * nBatch); - ZeroVector(*outputGateScratch, nCell * nBatch); - } - - // For each batch and cell: compute input_weight * input. - if (!useCifg) - { - MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, - nCell, nInput, *inputData, nBatch, *inputGateScratch); - } - MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, - nCell, nInput, *inputData, nBatch, *forgetGateScratch); - MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, - nCell, nInput, *inputData, nBatch, *cellScratch); - MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, - nCell, nInput, *inputData, nBatch, *outputGateScratch); - - // For each batch and cell: compute recurrent_weight * output_state. - if (!useCifg) - { - MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, - nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); - } - MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, - nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); - MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, - nCell, nOutput, *outputStateIn, nBatch, *cellScratch); - MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, - nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); - - // For each batch and cell: update input gate. - if (!useCifg) - { - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, - nCell, *cellStateIn, nBatch, *inputGateScratch); - } - if (useLayerNorm) - { - MeanStddevNormalization(*inputGateScratchDecoder, - *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*inputLayerNormWeights, - nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); - VectorBatchVectorAdd(*inputGateBiasTensor, - nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); - } - Activation(*inputGateScratchDecoder, *inputGateScratch, - TensorInfo({nCell, nBatch}, outputType), - ActivationFunction::Sigmoid, 0, 0); - } - - // For each batch and cell: update forget gate. - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, - *cellStateIn, nBatch, *forgetGateScratch); - } - if (useLayerNorm) - { - MeanStddevNormalization(*forgetGateScratchDecoder, - *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*forgetLayerNormWeights, - nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); - VectorBatchVectorAdd(*forgetGateBiasTensor, - nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); - } - Activation(*forgetGateScratchDecoder, *forgetGateScratch, - TensorInfo({nCell, nBatch}, outputType), - ActivationFunction::Sigmoid, 0, 0); - - // For each batch and cell: update the cell. - if (useLayerNorm) - { - MeanStddevNormalization(*cellScratchDecoder, - *cellScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*cellLayerNormWeights, - nCell, *cellScratchDecoder, nBatch, *cellScratch); - VectorBatchVectorAdd(*cellBiasTensor, - nCell, *cellScratchDecoder, nBatch, *cellScratch); - } - - VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); - - ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; - float a = 0; - float b = 0; - SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b); - - if (m_Data.m_Parameters.m_ActivationFunc > 0) - { - Activation(*cellScratchDecoder, *cellScratch, - TensorInfo({nCell, nBatch}, outputType), - armnnActivationFunc, a, b); - } - if (useCifg) - { - Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); - VectorVectorCwiseProductAccumulate( - *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); - } - else - { - VectorVectorCwiseProductAccumulate( - *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); - } - if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) - { - ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut); - } - - // For each batch and cell: update the output gate. - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, - nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); - } - if (useLayerNorm) - { - MeanStddevNormalization(*outputGateScratchDecoder, - *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*outputLayerNormWeights, - nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); - VectorBatchVectorAdd(*outputGateBiasTensor, - nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); - } - Activation(*outputGateScratchDecoder, *outputGateScratch, - TensorInfo({nCell, nBatch}, outputType), - ActivationFunction::Sigmoid, 0, 0); - - if (m_Data.m_Parameters.m_ActivationFunc > 0) - { - Activation(*cellStateOutDecoder, *cellScratch, - TensorInfo({nCell, nBatch}, outputType), - armnnActivationFunc, a, b); - } - - VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); - - // For each batch: update the projection and output_state. - if (m_Data.m_Parameters.m_ProjectionEnabled) - { - if (m_ProjectionBiasTensor) - { - VectorBatchVectorAssign(*projectionBiasTensor, - nOutput, nBatch, *output); - } - MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, - nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); - - if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) - { - ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output); - } - } - else - { - CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); - } - - CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); + LstmImpl(m_Data.m_Parameters, + inputInfo, + outputInfo, + inputToOutputWeightsShape, + recurrentToOutputWeightsShape, + inputData, + outputStateIn, + cellStateIn, + outputStateOut, + cellStateOut, + output, + cellStateOutDecoder, + outputDecoder, + inputToInputWeightsTensor, + inputToForgetWeightsTensor, + inputToCellWeightsTensor, + inputToOutputWeightsTensor, + recurrentToInputWeightsTensor, + recurrentToForgetWeightsTensor, + recurrentToCellWeightsTensor, + recurrentToOutputWeightsTensor, + cellToInputWeightsTensor, + cellToForgetWeightsTensor, + cellToOutputWeightsTensor, + inputGateBiasTensor, + forgetGateBiasTensor, + cellBiasTensor, + outputGateBiasTensor, + projectionWeightsTensor, + projectionBiasTensor, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights, + inputGateScratch, + cellScratch, + forgetGateScratch, + outputGateScratch, + inputGateScratchDecoder, + cellScratchDecoder, + forgetGateScratchDecoder, + outputGateScratchDecoder, + m_LayerNormEpsilon); } } //namespace armnn diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp new file mode 100644 index 0000000000..311fa18f91 --- /dev/null +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -0,0 +1,307 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefUnidirectionalSequenceLstmWorkload.hpp" +#include "Activation.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "Lstm.hpp" +#include "LstmUtils.hpp" +#include "RefWorkloadUtils.hpp" + +#include <armnnUtils/Permute.hpp> + +namespace armnn +{ + +RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload( + const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info) + , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) + , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) + , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) + , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) + , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) + , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) + , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) + , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) + , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) + , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) + , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) + , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) + , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) + , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) + , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) + , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) + , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) + , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) + , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) + , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) + , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) +{} + +void RefUnidirectionalSequenceLstmWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) +{ + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, + std::vector<ITensorHandle*> outputs) const +{ + TensorInfo inputInfo = GetTensorInfo(inputs[0]); + const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); + const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); + TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorShape& inputShape = inputInfo.GetShape(); + TensorShape& outputShape= outputInfo.GetShape(); + auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map()); + + if (!m_Data.m_Parameters.m_TimeMajor) + { + // Permute to time major + const PermutationVector& mappings = {1U, 0U, 2U}; + std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements()); + inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings); + inputInfo.SetShape(inputShape); + armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float)); + + outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); + outputInfo.SetShape(outputShape); + } + unsigned int maxTime = inputShape[0]; + unsigned int batchSize = inputShape[1]; + unsigned int outputSize = outputShape[2]; + unsigned int inputSize = inputShape[2]; + + TensorInfo scratchInfo = outputInfo; + scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]}); + + std::vector<float> inputGateScratchBuffer; + std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.); + std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.); + std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.); + + std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.); + std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.); + + void* outputStateOutData = outputStateOutBuffer.data(); + void* cellStateOutData = cellStateOutBuffer.data(); + + std::unique_ptr<Encoder<float>> inputGateScratch; + std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data()); + std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data()); + std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data()); + + std::unique_ptr<Decoder<float>> inputGateScratchDecoder; + std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data()); + std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo, + forgetGateScratchBuffer.data()); + std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo, + outputGateScratchBuffer.data()); + + const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; + const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; + const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; + + if (!useCifg) + { + inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.); + inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data()); + inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data()); + } + + std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData); + std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData); + std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData); + + TensorInfo lstmInputInfo = inputInfo; + TensorShape batchInputShape = TensorShape({batchSize, inputSize}); + lstmInputInfo.SetShape(batchInputShape); + + TensorInfo lstmOutputInfo = outputInfo; + lstmOutputInfo.SetShape({batchSize, outputSize}); + + const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); + const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); + unsigned int nOutput = recurrentToOutputWeightsShape[1]; + auto outputStateInData = inputs[1]->Map(); + std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData); + + auto cellStateInData = inputs[2]->Map(); + std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData); + + auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map()); + std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData); + auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map()); + std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData); + std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData); + + std::unique_ptr<Decoder<float>> inputToInputWeightsTensor; + std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>( + m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>( + m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>( + m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>()); + + std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor; + std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>( + m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>( + m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>( + m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>()); + + std::unique_ptr<Decoder<float>> inputGateBiasTensor; + std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>( + m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>( + m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>()); + std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>( + m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>()); + + std::unique_ptr<Decoder<float>> cellToInputWeightsTensor; + std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor; + std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor; + + std::unique_ptr<Decoder<float>> projectionWeightsTensor; + std::unique_ptr<Decoder<float>> projectionBiasTensor; + + std::unique_ptr<Decoder<float>> inputLayerNormWeights; + std::unique_ptr<Decoder<float>> forgetLayerNormWeights; + std::unique_ptr<Decoder<float>> cellLayerNormWeights; + std::unique_ptr<Decoder<float>> outputLayerNormWeights; + + if (useLayerNorm) + { + if (!useCifg) + { + inputLayerNormWeights = MakeDecoder<float>( + m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>()); + } + forgetLayerNormWeights = MakeDecoder<float>( + m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>()); + cellLayerNormWeights = MakeDecoder<float>( + m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>()); + outputLayerNormWeights = MakeDecoder<float>( + m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>()); + } + + if (!useCifg) + { + inputToInputWeightsTensor = MakeDecoder<float>( + m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>()); + inputGateBiasTensor = MakeDecoder<float>( + m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>()); + recurrentToInputWeightsTensor = MakeDecoder<float>( + m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>()); + } + + if (usePeephole) + { + cellToForgetWeightsTensor = MakeDecoder<float>( + m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>()); + cellToOutputWeightsTensor = MakeDecoder<float>( + m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>()); + } + + if (!useCifg && usePeephole) + { + cellToInputWeightsTensor = MakeDecoder<float>( + m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>()); + } + + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + projectionWeightsTensor = MakeDecoder<float>( + m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>()); + if (m_ProjectionBiasTensor) + { + projectionBiasTensor = MakeDecoder<float>( + m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>()); + } + } + + unsigned int batchInputSize = batchSize * inputSize; + unsigned int batchOutputSize = batchSize * nOutput; + + for (unsigned int t = 0; t < maxTime; ++t) + { + LstmImpl(m_Data.m_Parameters, + lstmInputInfo, + lstmOutputInfo, + inputToOutputWeightsShape, + recurrentToOutputWeightsShape, + inputData, + outputStateIn, + cellStateIn, + outputStateOut, + cellStateOut, + output, + cellStateOutDecoder, + outputDecoder, + inputToInputWeightsTensor, + inputToForgetWeightsTensor, + inputToCellWeightsTensor, + inputToOutputWeightsTensor, + recurrentToInputWeightsTensor, + recurrentToForgetWeightsTensor, + recurrentToCellWeightsTensor, + recurrentToOutputWeightsTensor, + cellToInputWeightsTensor, + cellToForgetWeightsTensor, + cellToOutputWeightsTensor, + inputGateBiasTensor, + forgetGateBiasTensor, + cellBiasTensor, + outputGateBiasTensor, + projectionWeightsTensor, + projectionBiasTensor, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights, + inputGateScratch, + cellScratch, + forgetGateScratch, + outputGateScratch, + inputGateScratchDecoder, + cellScratchDecoder, + forgetGateScratchDecoder, + outputGateScratchDecoder, + m_LayerNormEpsilon); + + currentInputData += batchInputSize; + inputData = MakeDecoder<float>(lstmInputInfo, currentInputData); + currentOutputData += batchOutputSize; + output = MakeEncoder<float>(lstmOutputInfo, currentOutputData); + outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData); + + // Assign output state out to the next output state in + outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData); + + // Assign cell state out to the next cell state in + cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData); + } + + if (!m_Data.m_Parameters.m_TimeMajor) + { + // Permute Output back to batch major + const PermutationVector& mappings = {1U, 0U, 2U}; + auto outputData = reinterpret_cast<float*>(outputs[0]->Map()); + std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements()); + outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); + outputInfo.SetShape(outputShape); + armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float)); + } +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp new file mode 100644 index 0000000000..8ba7bdc0c6 --- /dev/null +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp @@ -0,0 +1,56 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/TypesUtils.hpp> + +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> + +#include "Encoders.hpp" +#include "Decoders.hpp" + +namespace armnn +{ + +class RefUnidirectionalSequenceLstmWorkload : public BaseWorkload<UnidirectionalSequenceLstmQueueDescriptor> +{ +public: + explicit RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; + + +private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; + std::unique_ptr<ScopedTensorHandle> m_InputToInputWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_InputToForgetWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_InputToCellWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_InputToOutputWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_RecurrentToInputWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_RecurrentToForgetWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_RecurrentToCellWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_RecurrentToOutputWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_CellToInputWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_CellToForgetWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_CellToOutputWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_InputGateBiasTensor; + std::unique_ptr<ScopedTensorHandle> m_ForgetGateBiasTensor; + std::unique_ptr<ScopedTensorHandle> m_CellBiasTensor; + std::unique_ptr<ScopedTensorHandle> m_OutputGateBiasTensor; + std::unique_ptr<ScopedTensorHandle> m_ProjectionWeightsTensor; + std::unique_ptr<ScopedTensorHandle> m_ProjectionBiasTensor; + std::unique_ptr<ScopedTensorHandle> m_InputLayerNormWeights; + std::unique_ptr<ScopedTensorHandle> m_ForgetLayerNormWeights; + std::unique_ptr<ScopedTensorHandle> m_CellLayerNormWeights; + std::unique_ptr<ScopedTensorHandle> m_OutputLayerNormWeights; + + float m_LayerNormEpsilon = static_cast<float>(1e-8); +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index afe63d13c0..d3ae58ea15 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -69,6 +69,7 @@ #include "RefSpaceToDepthWorkload.hpp" #include "RefTransposeConvolution2dWorkload.hpp" #include "RefTransposeWorkload.hpp" +#include "RefUnidirectionalSequenceLstmWorkload.hpp" #include "RefWorkloadUtils.hpp" #include "Resize.hpp" #include "Softmax.hpp" |