diff options
Diffstat (limited to 'src/backends/reference/workloads/RefLstmWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefLstmWorkload.cpp | 235 |
1 files changed, 47 insertions, 188 deletions
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp index 3ddfd334b8..1ff6f50ed5 100644 --- a/src/backends/reference/workloads/RefLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefLstmWorkload.cpp @@ -7,6 +7,7 @@ #include "Activation.hpp" #include "Encoders.hpp" #include "Decoders.hpp" +#include "Lstm.hpp" #include "LstmUtils.hpp" #include "RefWorkloadUtils.hpp" @@ -57,7 +58,6 @@ void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); const TensorShape& inputShape = inputInfo.GetShape(); - const DataType& outputType = outputInfo.GetDataType(); std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map()); std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map()); @@ -71,10 +71,7 @@ void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map()); const uint32_t nBatch = inputShape[0]; - const uint32_t nInput = inputShape[1]; - const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; - const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; @@ -154,6 +151,9 @@ void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT std::unique_ptr<Decoder<float>> cellLayerNormWeights; std::unique_ptr<Decoder<float>> outputLayerNormWeights; + const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); + const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); + if (useLayerNorm) { if (!useCifg) @@ -204,190 +204,49 @@ void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT } } - if (!useLayerNorm) - { - // Initialize scratch buffers with bias. - if (!useCifg) - { - VectorBatchVectorAssign(*inputGateBiasTensor, - nCell, nBatch, *inputGateScratch); - } - VectorBatchVectorAssign(*forgetGateBiasTensor, - nCell, nBatch, *forgetGateScratch); - VectorBatchVectorAssign(*cellBiasTensor, - nCell, nBatch, *cellScratch); - VectorBatchVectorAssign(*outputGateBiasTensor, - nCell, nBatch, *outputGateScratch); - } - else - { - // Initialize scratch buffers with zeroes. - if (!useCifg) - { - ZeroVector(*inputGateScratch, nCell * nBatch); - } - ZeroVector(*forgetGateScratch, nCell * nBatch); - ZeroVector(*cellScratch , nCell * nBatch); - ZeroVector(*outputGateScratch, nCell * nBatch); - } - - // For each batch and cell: compute input_weight * input. - if (!useCifg) - { - MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, - nCell, nInput, *inputData, nBatch, *inputGateScratch); - } - MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, - nCell, nInput, *inputData, nBatch, *forgetGateScratch); - MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, - nCell, nInput, *inputData, nBatch, *cellScratch); - MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, - nCell, nInput, *inputData, nBatch, *outputGateScratch); - - // For each batch and cell: compute recurrent_weight * output_state. - if (!useCifg) - { - MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, - nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); - } - MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, - nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); - MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, - nCell, nOutput, *outputStateIn, nBatch, *cellScratch); - MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, - nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); - - // For each batch and cell: update input gate. - if (!useCifg) - { - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, - nCell, *cellStateIn, nBatch, *inputGateScratch); - } - if (useLayerNorm) - { - MeanStddevNormalization(*inputGateScratchDecoder, - *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*inputLayerNormWeights, - nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); - VectorBatchVectorAdd(*inputGateBiasTensor, - nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); - } - Activation(*inputGateScratchDecoder, *inputGateScratch, - TensorInfo({nCell, nBatch}, outputType), - ActivationFunction::Sigmoid, 0, 0); - } - - // For each batch and cell: update forget gate. - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, - *cellStateIn, nBatch, *forgetGateScratch); - } - if (useLayerNorm) - { - MeanStddevNormalization(*forgetGateScratchDecoder, - *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*forgetLayerNormWeights, - nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); - VectorBatchVectorAdd(*forgetGateBiasTensor, - nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); - } - Activation(*forgetGateScratchDecoder, *forgetGateScratch, - TensorInfo({nCell, nBatch}, outputType), - ActivationFunction::Sigmoid, 0, 0); - - // For each batch and cell: update the cell. - if (useLayerNorm) - { - MeanStddevNormalization(*cellScratchDecoder, - *cellScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*cellLayerNormWeights, - nCell, *cellScratchDecoder, nBatch, *cellScratch); - VectorBatchVectorAdd(*cellBiasTensor, - nCell, *cellScratchDecoder, nBatch, *cellScratch); - } - - VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); - - ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; - float a = 0; - float b = 0; - SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b); - - if (m_Data.m_Parameters.m_ActivationFunc > 0) - { - Activation(*cellScratchDecoder, *cellScratch, - TensorInfo({nCell, nBatch}, outputType), - armnnActivationFunc, a, b); - } - if (useCifg) - { - Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); - VectorVectorCwiseProductAccumulate( - *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); - } - else - { - VectorVectorCwiseProductAccumulate( - *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); - } - if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) - { - ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut); - } - - // For each batch and cell: update the output gate. - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, - nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); - } - if (useLayerNorm) - { - MeanStddevNormalization(*outputGateScratchDecoder, - *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*outputLayerNormWeights, - nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); - VectorBatchVectorAdd(*outputGateBiasTensor, - nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); - } - Activation(*outputGateScratchDecoder, *outputGateScratch, - TensorInfo({nCell, nBatch}, outputType), - ActivationFunction::Sigmoid, 0, 0); - - if (m_Data.m_Parameters.m_ActivationFunc > 0) - { - Activation(*cellStateOutDecoder, *cellScratch, - TensorInfo({nCell, nBatch}, outputType), - armnnActivationFunc, a, b); - } - - VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); - - // For each batch: update the projection and output_state. - if (m_Data.m_Parameters.m_ProjectionEnabled) - { - if (m_ProjectionBiasTensor) - { - VectorBatchVectorAssign(*projectionBiasTensor, - nOutput, nBatch, *output); - } - MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, - nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); - - if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) - { - ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output); - } - } - else - { - CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); - } - - CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); + LstmImpl(m_Data.m_Parameters, + inputInfo, + outputInfo, + inputToOutputWeightsShape, + recurrentToOutputWeightsShape, + inputData, + outputStateIn, + cellStateIn, + outputStateOut, + cellStateOut, + output, + cellStateOutDecoder, + outputDecoder, + inputToInputWeightsTensor, + inputToForgetWeightsTensor, + inputToCellWeightsTensor, + inputToOutputWeightsTensor, + recurrentToInputWeightsTensor, + recurrentToForgetWeightsTensor, + recurrentToCellWeightsTensor, + recurrentToOutputWeightsTensor, + cellToInputWeightsTensor, + cellToForgetWeightsTensor, + cellToOutputWeightsTensor, + inputGateBiasTensor, + forgetGateBiasTensor, + cellBiasTensor, + outputGateBiasTensor, + projectionWeightsTensor, + projectionBiasTensor, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights, + inputGateScratch, + cellScratch, + forgetGateScratch, + outputGateScratch, + inputGateScratchDecoder, + cellScratchDecoder, + forgetGateScratchDecoder, + outputGateScratchDecoder, + m_LayerNormEpsilon); } } //namespace armnn |