Diffstat (limited to 'src/backends/reference/workloads/Lstm.cpp')
-rw-r--r-- src/backends/reference/workloads/Lstm.cpp 259
1 file changed, 259 insertions(+), 0 deletions(-)
diff --git a/src/backends/reference/workloads/Lstm.cpp b/src/backends/reference/workloads/Lstm.cpp
new file mode 100644
index 0000000000..c1fb2bf4aa
--- /dev/null
+++ b/src/backends/reference/workloads/Lstm.cpp
@@ -0,0 +1,259 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Activation.hpp"
+#include "Lstm.hpp"
+#include "LstmUtils.hpp"
+
+namespace armnn
+{
+
+void LstmImpl(const LstmDescriptor& descriptor,
+ const TensorInfo& inputInfo,
+ const TensorInfo& outputInfo,
+ const TensorShape& inputToOutputWeightsShape,
+ const TensorShape& recurrentToOutputWeightsShape,
+ std::unique_ptr<Decoder<float>>& inputData,
+ std::unique_ptr<Decoder<float>>& outputStateIn,
+ std::unique_ptr<Decoder<float>>& cellStateIn,
+ std::unique_ptr<Encoder<float>>& outputStateOut,
+ std::unique_ptr<Encoder<float>>& cellStateOut,
+ std::unique_ptr<Encoder<float>>& output,
+ std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
+ std::unique_ptr<Decoder<float>>& outputDecoder,
+ std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
+ std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
+ std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
+ std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
+ std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
+ std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
+ std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
+ std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
+ std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
+ std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
+ std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
+ std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
+ std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
+ std::unique_ptr<Decoder<float>>& cellBiasTensor,
+ std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
+ std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
+ std::unique_ptr<Decoder<float>>& projectionBiasTensor,
+ std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
+ std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
+ std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
+ std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
+ std::unique_ptr<Encoder<float>>& inputGateScratch,
+ std::unique_ptr<Encoder<float>>& cellScratch,
+ std::unique_ptr<Encoder<float>>& forgetGateScratch,
+ std::unique_ptr<Encoder<float>>& outputGateScratch,
+ std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
+ std::unique_ptr<Decoder<float>>& cellScratchDecoder,
+ std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
+ std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
+ float layerNormEpsilon)
+{
+ // This is a port of the LSTM::Eval() method from the Android code base.
+ // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
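+ //
+ // The computation below follows the standard LSTM cell equations, with
+ // optional CIFG, peephole, projection and layer-normalisation variants:
+ //   i_t = sigmoid(W_xi*x_t + W_hi*h_{t-1} + w_ci.c_{t-1} + b_i)
+ //   f_t = sigmoid(W_xf*x_t + W_hf*h_{t-1} + w_cf.c_{t-1} + b_f)
+ //   c_t = f_t.c_{t-1} + i_t.g(W_xc*x_t + W_hc*h_{t-1} + b_c)
+ //   o_t = sigmoid(W_xo*x_t + W_ho*h_{t-1} + w_co.c_t + b_o)
+ //   h_t = o_t.g(c_t), optionally followed by a projection
+ // where '.' is element-wise multiplication and g() the configured activation.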
+
+ const TensorShape& inputShape = inputInfo.GetShape();
+ const DataType& outputType = outputInfo.GetDataType();
+
+ const uint32_t nBatch = inputShape[0];
+ const uint32_t nInput = inputShape[1];
+
+ const uint32_t nCell = inputToOutputWeightsShape[0];
+ const uint32_t nOutput = recurrentToOutputWeightsShape[1];
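+
+ // inputToOutputWeights is [nCell, nInput] and recurrentToOutputWeights is
+ // [nCell, nOutput], which is where the cell and output sizes come from.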
+
+ const bool useCifg = descriptor.m_CifgEnabled;
+ const bool usePeephole = descriptor.m_PeepholeEnabled;
+ const bool useLayerNorm = descriptor.m_LayerNormEnabled;
+
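+ // Without layer normalisation the scratch buffers are seeded with the gate
+ // biases; with it they start from zero, since the biases are only added
+ // after each gate is normalised below.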
+ if (!useLayerNorm)
+ {
+ // Initialize scratch buffers with bias.
+ if (!useCifg)
+ {
+ VectorBatchVectorAssign(*inputGateBiasTensor,
+ nCell, nBatch, *inputGateScratch);
+ }
+ VectorBatchVectorAssign(*forgetGateBiasTensor,
+ nCell, nBatch, *forgetGateScratch);
+ VectorBatchVectorAssign(*cellBiasTensor,
+ nCell, nBatch, *cellScratch);
+ VectorBatchVectorAssign(*outputGateBiasTensor,
+ nCell, nBatch, *outputGateScratch);
+ }
+ else
+ {
+ // Initialize scratch buffers with zeroes.
+ if (!useCifg)
+ {
+ ZeroVector(*inputGateScratch, nCell * nBatch);
+ }
+ ZeroVector(*forgetGateScratch, nCell * nBatch);
+ ZeroVector(*cellScratch, nCell * nBatch);
+ ZeroVector(*outputGateScratch, nCell * nBatch);
+ }
+
+ // For each batch and cell: compute input_weight * input.
+ if (!useCifg)
+ {
+ MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *inputGateScratch);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *forgetGateScratch);
+ MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *cellScratch);
+ MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *outputGateScratch);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!useCifg)
+ {
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
+
+ // For each batch and cell: update input gate.
+ if (!useCifg)
+ {
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
+ nCell, *cellStateIn, nBatch, *inputGateScratch);
+ }
+ if (useLayerNorm)
+ {
+ MeanStddevNormalization(*inputGateScratchDecoder,
+ *inputGateScratch, nCell, nBatch, layerNormEpsilon);
+ VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
+ nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+ VectorBatchVectorAdd(*inputGateBiasTensor,
+ nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+ }
+ Activation(*inputGateScratchDecoder, *inputGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
+ *cellStateIn, nBatch, *forgetGateScratch);
+ }
+ if (useLayerNorm)
+ {
+ MeanStddevNormalization(*forgetGateScratchDecoder,
+ *forgetGateScratch, nCell, nBatch, layerNormEpsilon);
+ VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
+ nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+ VectorBatchVectorAdd(*forgetGateBiasTensor,
+ nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+ }
+ Activation(*forgetGateScratchDecoder, *forgetGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+
+ // For each batch and cell: update the cell.
+ if (useLayerNorm)
+ {
+ MeanStddevNormalization(*cellScratchDecoder,
+ *cellScratch, nCell, nBatch, layerNormEpsilon);
+ VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
+ nCell, *cellScratchDecoder, nBatch, *cellScratch);
+ VectorBatchVectorAdd(*cellBiasTensor,
+ nCell, *cellScratchDecoder, nBatch, *cellScratch);
+ }
+
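+ // Start the new cell state as f_t.c_{t-1}.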
+ VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
+
+ ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
+ float a = 0;
+ float b = 0;
+ SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b);
+
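+ // m_ActivationFunc == 0 means "no activation", so the cell candidate is
+ // only passed through g() when an activation has been configured.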
+ if (descriptor.m_ActivationFunc > 0)
+ {
+ Activation(*cellScratchDecoder, *cellScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ armnnActivationFunc, a, b);
+ }
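+ // With CIFG the input gate is coupled to the forget gate, i_t = 1 - f_t,
+ // computed in place from the forget gate scratch.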
+ if (useCifg)
+ {
+ Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
+ VectorVectorCwiseProductAccumulate(
+ *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
+ }
+ else
+ {
+ VectorVectorCwiseProductAccumulate(
+ *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
+ }
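+ // Optionally clip the new cell state to [-m_ClippingThresCell, m_ClippingThresCell].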
+ if (descriptor.m_ClippingThresCell > 0.0)
+ {
+ ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
+ nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
+ }
+ if (useLayerNorm)
+ {
+ MeanStddevNormalization(*outputGateScratchDecoder,
+ *outputGateScratch, nCell, nBatch, layerNormEpsilon);
+ VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
+ nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+ VectorBatchVectorAdd(*outputGateBiasTensor,
+ nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+ }
+ Activation(*outputGateScratchDecoder, *outputGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+
+ if (descriptor.m_ActivationFunc > 0)
+ {
+ Activation(*cellStateOutDecoder, *cellScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ armnnActivationFunc, a, b);
+ }
+
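+ // h_t before any projection: o_t.g(c_t).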
+ VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
+
+ // For each batch: update the projection and output_state.
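+ // With projection: output = clip(W_proj*h_t + b_proj); otherwise the gated
+ // cell output is copied straight through.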
+ if (descriptor.m_ProjectionEnabled)
+ {
+ if (projectionBiasTensor)
+ {
+ VectorBatchVectorAssign(*projectionBiasTensor,
+ nOutput, nBatch, *output);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
+ nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
+
+ if (descriptor.m_ClippingThresProj > 0.0)
+ {
+ ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output);
+ }
+ }
+ else
+ {
+ CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
+ }
+
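+ // The final output also becomes the new output state h_t.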
+ CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
+}
+
+} // namespace armnn