From eb2b329b761ce3206505ed8d2eab071a2f97d5e7 Mon Sep 17 00:00:00 2001 From: Nattapat Chaimanowong Date: Tue, 7 May 2019 12:02:30 +0100 Subject: IVGCVSW-2997 Refactor reference LSTM workload Signed-off-by: Nattapat Chaimanowong Change-Id: I6883f878d9f701a55153292769d2fc0530d2529e --- .../reference/workloads/RefLstmWorkload.cpp | 307 +++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 src/backends/reference/workloads/RefLstmWorkload.cpp (limited to 'src/backends/reference/workloads/RefLstmWorkload.cpp') diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp new file mode 100644 index 0000000000..f8ebc58f6e --- /dev/null +++ b/src/backends/reference/workloads/RefLstmWorkload.cpp @@ -0,0 +1,307 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefLstmWorkload.hpp" +#include "Activation.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "LstmUtils.hpp" +#include "RefWorkloadUtils.hpp" + +namespace armnn +{ + +RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) + : BaseWorkload(descriptor, info) + , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights)) + , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights)) + , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights)) + , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights)) + , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights)) + , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights)) + , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights)) + , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights)) + , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights)) + , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights)) + , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights)) + , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias)) + , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias)) + , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias)) + , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias)) + , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights)) + , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias)) +{} + +void RefLstmWorkload::Execute() const +{ + // This is a porting of the LSTM::Eval() method in the Android code base + // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp + + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + const TensorShape& inputShape = inputInfo.GetShape(); + const DataType& outputType = outputInfo.GetDataType(); + + std::unique_ptr> outputStateOut = MakeEncoder(outputInfo, m_Data.m_Outputs[1]->Map()); + std::unique_ptr> cellStateOut = MakeEncoder(outputInfo, m_Data.m_Outputs[2]->Map()); + std::unique_ptr> output = MakeEncoder(outputInfo, m_Data.m_Outputs[3]->Map()); + + std::unique_ptr> cellStateOutDecoder = MakeDecoder(outputInfo, m_Data.m_Outputs[2]->Map()); + std::unique_ptr> outputDecoder = MakeDecoder(outputInfo, m_Data.m_Outputs[3]->Map()); + + std::unique_ptr> inputData = MakeDecoder(inputInfo, m_Data.m_Inputs[0]->Map()); + std::unique_ptr> outputStateIn = MakeDecoder(inputInfo, m_Data.m_Inputs[1]->Map()); + std::unique_ptr> cellStateIn = MakeDecoder(inputInfo, m_Data.m_Inputs[2]->Map()); + + const uint32_t nBatch = inputShape[0]; + const uint32_t nInput = inputShape[1]; + + const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; + const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; + + const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; + const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; + + // Index the scratch buffers pointers to the global scratch buffer. + std::unique_ptr> inputGateScratch = MakeEncoder(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr> cellScratch = MakeEncoder(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr> forgetGateScratch = MakeEncoder(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr> outputGateScratch = MakeEncoder(outputInfo, m_Data.m_Outputs[0]->Map()); + + std::unique_ptr> inputGateScratchDecoder = + MakeDecoder(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr> cellScratchDecoder = + MakeDecoder(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr> forgetGateScratchDecoder = + MakeDecoder(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr> outputGateScratchDecoder = + MakeDecoder(outputInfo, m_Data.m_Outputs[0]->Map()); + + if (useCifg) + { + *cellScratch += (0 * nCell * nBatch); + *forgetGateScratch += (1 * nCell * nBatch); + *outputGateScratch += (2 * nCell * nBatch); + + *cellScratchDecoder += (0 * nCell * nBatch); + *forgetGateScratchDecoder += (1 * nCell * nBatch); + *outputGateScratchDecoder += (2 * nCell * nBatch); + } + else + { + *inputGateScratch += (0 * nCell * nBatch); + *cellScratch += (1 * nCell * nBatch); + *forgetGateScratch += (2 * nCell * nBatch); + *outputGateScratch += (3 * nCell * nBatch); + + *inputGateScratchDecoder += (0 * nCell * nBatch); + *cellScratchDecoder += (1 * nCell * nBatch); + *forgetGateScratchDecoder += (2 * nCell * nBatch); + *outputGateScratchDecoder += (3 * nCell * nBatch); + } + + std::unique_ptr> inputToInputWeightsTensor; + std::unique_ptr> inputToForgetWeightsTensor = MakeDecoder( + m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor()); + std::unique_ptr> inputToCellWeightsTensor = MakeDecoder( + m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor()); + std::unique_ptr> inputToOutputWeightsTensor = MakeDecoder( + m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor()); + + std::unique_ptr> recurrentToInputWeightsTensor; + std::unique_ptr> recurrentToForgetWeightsTensor = MakeDecoder( + m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor()); + std::unique_ptr> recurrentToCellWeightsTensor = MakeDecoder( + m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor()); + std::unique_ptr> recurrentToOutputWeightsTensor = MakeDecoder( + m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor()); + + std::unique_ptr> inputGateBiasTensor; + std::unique_ptr> forgetGateBiasTensor = MakeDecoder( + m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor()); + std::unique_ptr> cellBiasTensor = MakeDecoder( + m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor()); + std::unique_ptr> outputGateBiasTensor = MakeDecoder( + m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor()); + + std::unique_ptr> cellToInputWeightsTensor; + std::unique_ptr> cellToForgetWeightsTensor; + std::unique_ptr> cellToOutputWeightsTensor; + + std::unique_ptr> projectionWeightsTensor; + std::unique_ptr> projectionBiasTensor; + + if (!useCifg) + { + inputToInputWeightsTensor = MakeDecoder( + m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor()); + inputGateBiasTensor = MakeDecoder( + m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor()); + recurrentToInputWeightsTensor = MakeDecoder( + m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor()); + } + + if (usePeephole) + { + cellToForgetWeightsTensor = MakeDecoder( + m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor()); + cellToOutputWeightsTensor = MakeDecoder( + m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor()); + } + + if (!useCifg && usePeephole) + { + cellToInputWeightsTensor = MakeDecoder( + m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor()); + } + + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + projectionWeightsTensor = MakeDecoder( + m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor()); + if (m_ProjectionBiasTensor) + { + projectionBiasTensor = MakeDecoder( + m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor()); + } + } + + // Initialize scratch buffers with bias. + if (!useCifg) + { + VectorBatchVectorAssign(*inputGateBiasTensor, + nCell, nBatch, *inputGateScratch); + } + VectorBatchVectorAssign(*forgetGateBiasTensor, + nCell, nBatch, *forgetGateScratch); + VectorBatchVectorAssign(*cellBiasTensor, + nCell, nBatch, *cellScratch); + VectorBatchVectorAssign(*outputGateBiasTensor, + nCell, nBatch, *outputGateScratch); + + // For each batch and cell: compute input_weight * input. + if (!useCifg) + { + MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, + nCell, nInput, *inputData, nBatch, *inputGateScratch); + } + MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, + nCell, nInput, *inputData, nBatch, *forgetGateScratch); + MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, + nCell, nInput, *inputData, nBatch, *cellScratch); + MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, + nCell, nInput, *inputData, nBatch, *outputGateScratch); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!useCifg) + { + MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); + } + MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); + MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *cellScratch); + MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); + + // For each batch and cell: update input gate. + if (!useCifg) + { + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, + nCell, *cellStateIn, nBatch, *inputGateScratch); + } + Activation(*inputGateScratchDecoder, *inputGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + } + + // For each batch and cell: update forget gate. + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, + *cellStateIn, nBatch, *forgetGateScratch); + } + Activation(*forgetGateScratchDecoder, *forgetGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + + // For each batch and cell: update the cell. + VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); + + ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; + float a = 0; + float b = 0; + SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b); + + if (m_Data.m_Parameters.m_ActivationFunc > 0) + { + Activation(*cellScratchDecoder, *cellScratch, + TensorInfo({nCell, nBatch}, outputType), + armnnActivationFunc, a, b); + } + if (useCifg) + { + Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); + VectorVectorCwiseProductAccumulate( + *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); + } + else + { + VectorVectorCwiseProductAccumulate( + *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); + } + if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) + { + ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut); + } + + // For each batch and cell: update the output gate. + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, + nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); + } + Activation(*outputGateScratchDecoder, *outputGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + + if (m_Data.m_Parameters.m_ActivationFunc > 0) + { + Activation(*cellStateOutDecoder, *cellScratch, + TensorInfo({nCell, nBatch}, outputType), + armnnActivationFunc, a, b); + } + + VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); + + // For each batch: update the projection and output_state. + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + if (m_ProjectionBiasTensor) + { + VectorBatchVectorAssign(*projectionBiasTensor, + nOutput, nBatch, *output); + } + MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, + nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); + + if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) + { + ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output); + } + } + else + { + CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); + } + + CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); +} + +} //namespace armnn -- cgit v1.2.1