aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefLstmWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefLstmWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefLstmWorkload.cpp307
1 files changed, 307 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp
new file mode 100644
index 0000000000..f8ebc58f6e
--- /dev/null
+++ b/src/backends/reference/workloads/RefLstmWorkload.cpp
@@ -0,0 +1,307 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefLstmWorkload.hpp"
+#include "Activation.hpp"
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+#include "LstmUtils.hpp"
+#include "RefWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
+ : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
+ , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
+ , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
+ , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
+ , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
+ , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
+ , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
+ , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
+ , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
+ , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
+ , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
+ , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
+ , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
+ , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
+ , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
+ , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
+ , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
+ , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
+{}
+
+void RefLstmWorkload::Execute() const
+{
+ // This is a porting of the LSTM::Eval() method in the Android code base
+ // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
+
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inputShape = inputInfo.GetShape();
+ const DataType& outputType = outputInfo.GetDataType();
+
+ std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map());
+ std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+ std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());
+
+ std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+ std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());
+
+ std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+ std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map());
+ std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map());
+
+ const uint32_t nBatch = inputShape[0];
+ const uint32_t nInput = inputShape[1];
+
+ const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
+ const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
+
+ const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
+ const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+ std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
+ MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Decoder<float>> cellScratchDecoder =
+ MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
+ MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
+ MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+ if (useCifg)
+ {
+ *cellScratch += (0 * nCell * nBatch);
+ *forgetGateScratch += (1 * nCell * nBatch);
+ *outputGateScratch += (2 * nCell * nBatch);
+
+ *cellScratchDecoder += (0 * nCell * nBatch);
+ *forgetGateScratchDecoder += (1 * nCell * nBatch);
+ *outputGateScratchDecoder += (2 * nCell * nBatch);
+ }
+ else
+ {
+ *inputGateScratch += (0 * nCell * nBatch);
+ *cellScratch += (1 * nCell * nBatch);
+ *forgetGateScratch += (2 * nCell * nBatch);
+ *outputGateScratch += (3 * nCell * nBatch);
+
+ *inputGateScratchDecoder += (0 * nCell * nBatch);
+ *cellScratchDecoder += (1 * nCell * nBatch);
+ *forgetGateScratchDecoder += (2 * nCell * nBatch);
+ *outputGateScratchDecoder += (3 * nCell * nBatch);
+ }
+
+ std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
+ std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
+ m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
+ m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
+ m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());
+
+ std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
+ std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
+ m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
+ m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
+ m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());
+
+ std::unique_ptr<Decoder<float>> inputGateBiasTensor;
+ std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
+ m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
+ m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
+ m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<void>());
+
+ std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
+ std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
+ std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
+
+ std::unique_ptr<Decoder<float>> projectionWeightsTensor;
+ std::unique_ptr<Decoder<float>> projectionBiasTensor;
+
+ if (!useCifg)
+ {
+ inputToInputWeightsTensor = MakeDecoder<float>(
+ m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
+ inputGateBiasTensor = MakeDecoder<float>(
+ m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<void>());
+ recurrentToInputWeightsTensor = MakeDecoder<float>(
+ m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
+ }
+
+ if (usePeephole)
+ {
+ cellToForgetWeightsTensor = MakeDecoder<float>(
+ m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
+ cellToOutputWeightsTensor = MakeDecoder<float>(
+ m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
+ }
+
+ if (!useCifg && usePeephole)
+ {
+ cellToInputWeightsTensor = MakeDecoder<float>(
+ m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
+ }
+
+ if (m_Data.m_Parameters.m_ProjectionEnabled)
+ {
+ projectionWeightsTensor = MakeDecoder<float>(
+ m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
+ if (m_ProjectionBiasTensor)
+ {
+ projectionBiasTensor = MakeDecoder<float>(
+ m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
+ }
+ }
+
+ // Initialize scratch buffers with bias.
+ if (!useCifg)
+ {
+ VectorBatchVectorAssign(*inputGateBiasTensor,
+ nCell, nBatch, *inputGateScratch);
+ }
+ VectorBatchVectorAssign(*forgetGateBiasTensor,
+ nCell, nBatch, *forgetGateScratch);
+ VectorBatchVectorAssign(*cellBiasTensor,
+ nCell, nBatch, *cellScratch);
+ VectorBatchVectorAssign(*outputGateBiasTensor,
+ nCell, nBatch, *outputGateScratch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!useCifg)
+ {
+ MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *inputGateScratch);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *forgetGateScratch);
+ MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *cellScratch);
+ MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *outputGateScratch);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!useCifg)
+ {
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
+
+ // For each batch and cell: update input gate.
+ if (!useCifg)
+ {
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
+ nCell, *cellStateIn, nBatch, *inputGateScratch);
+ }
+ Activation(*inputGateScratchDecoder, *inputGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
+ *cellStateIn, nBatch, *forgetGateScratch);
+ }
+ Activation(*forgetGateScratchDecoder, *forgetGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+
+ // For each batch and cell: update the cell.
+ VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
+
+ ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
+ float a = 0;
+ float b = 0;
+ SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);
+
+ if (m_Data.m_Parameters.m_ActivationFunc > 0)
+ {
+ Activation(*cellScratchDecoder, *cellScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ armnnActivationFunc, a, b);
+ }
+ if (useCifg)
+ {
+ Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
+ VectorVectorCwiseProductAccumulate(
+ *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
+ }
+ else
+ {
+ VectorVectorCwiseProductAccumulate(
+ *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
+ }
+ if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
+ {
+ ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
+ nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
+ }
+ Activation(*outputGateScratchDecoder, *outputGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+
+ if (m_Data.m_Parameters.m_ActivationFunc > 0)
+ {
+ Activation(*cellStateOutDecoder, *cellScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ armnnActivationFunc, a, b);
+ }
+
+ VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
+
+ // For each batch: update the projection and output_state.
+ if (m_Data.m_Parameters.m_ProjectionEnabled)
+ {
+ if (m_ProjectionBiasTensor)
+ {
+ VectorBatchVectorAssign(*projectionBiasTensor,
+ nOutput, nBatch, *output);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
+ nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
+
+ if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
+ {
+ ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output);
+ }
+ }
+ else
+ {
+ CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
+ }
+
+ CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
+}
+
+} //namespace armnn