Diffstat (limited to 'src/backends/reference/workloads/RefLstmWorkload.cpp')
-rw-r--r--  src/backends/reference/workloads/RefLstmWorkload.cpp  100
1 file changed, 88 insertions, 12 deletions
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp
index f8ebc58f6e..70b3443d88 100644
--- a/src/backends/reference/workloads/RefLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefLstmWorkload.cpp
@@ -32,6 +32,10 @@ RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const Wo
     , m_OutputGateBiasTensor    (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
     , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
     , m_ProjectionBiasTensor    (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
+    , m_InputLayerNormWeights   (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
+    , m_ForgetLayerNormWeights  (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
+    , m_CellLayerNormWeights    (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
+    , m_OutputLayerNormWeights  (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
 {}
 
 void RefLstmWorkload::Execute() const
@@ -62,8 +66,9 @@ void RefLstmWorkload::Execute() const
     const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
     const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
 
-    const bool useCifg     = m_Data.m_Parameters.m_CifgEnabled;
-    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
+    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
+    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
+    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
 
     // Index the scratch buffers pointers to the global scratch buffer.
     std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
@@ -134,6 +139,26 @@ void RefLstmWorkload::Execute() const
     std::unique_ptr<Decoder<float>> projectionWeightsTensor;
     std::unique_ptr<Decoder<float>> projectionBiasTensor;
 
+    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
+    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
+    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
+    std::unique_ptr<Decoder<float>> outputLayerNormWeights;
+
+    if (useLayerNorm)
+    {
+        if (!useCifg)
+        {
+            inputLayerNormWeights = MakeDecoder<float>(
+                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetTensor<void>());
+        }
+        forgetLayerNormWeights = MakeDecoder<float>(
+            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetTensor<void>());
+        cellLayerNormWeights = MakeDecoder<float>(
+            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetTensor<void>());
+        outputLayerNormWeights = MakeDecoder<float>(
+            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetTensor<void>());
+    }
+
     if (!useCifg)
     {
         inputToInputWeightsTensor = MakeDecoder<float>(
@@ -169,18 +194,32 @@ void RefLstmWorkload::Execute() const
         }
     }
 
-    // Initialize scratch buffers with bias.
-    if (!useCifg)
+    if (!useLayerNorm)
     {
-        VectorBatchVectorAssign(*inputGateBiasTensor,
-                                nCell, nBatch, *inputGateScratch);
+        // Initialize scratch buffers with bias.
+        if (!useCifg)
+        {
+            VectorBatchVectorAssign(*inputGateBiasTensor,
+                                    nCell, nBatch, *inputGateScratch);
+        }
+        VectorBatchVectorAssign(*forgetGateBiasTensor,
+                                nCell, nBatch, *forgetGateScratch);
+        VectorBatchVectorAssign(*cellBiasTensor,
+                                nCell, nBatch, *cellScratch);
+        VectorBatchVectorAssign(*outputGateBiasTensor,
+                                nCell, nBatch, *outputGateScratch);
+    }
+    else
+    {
+        // Initialize scratch buffers with zeroes.
+        if (!useCifg)
+        {
+            ZeroVector(*inputGateScratch, nCell * nBatch);
+        }
+        ZeroVector(*forgetGateScratch, nCell * nBatch);
+        ZeroVector(*cellScratch      , nCell * nBatch);
+        ZeroVector(*outputGateScratch, nCell * nBatch);
     }
-    VectorBatchVectorAssign(*forgetGateBiasTensor,
-                            nCell, nBatch, *forgetGateScratch);
-    VectorBatchVectorAssign(*cellBiasTensor,
-                            nCell, nBatch, *cellScratch);
-    VectorBatchVectorAssign(*outputGateBiasTensor,
-                            nCell, nBatch, *outputGateScratch);
 
     // For each batch and cell: compute input_weight * input.
     if (!useCifg)
@@ -216,6 +255,15 @@ void RefLstmWorkload::Execute() const
             VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
                                                     nCell, *cellStateIn, nBatch, *inputGateScratch);
         }
+        if (useLayerNorm)
+        {
+            MeanStddevNormalization(*inputGateScratchDecoder,
+                                    *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+            VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
+                                          nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+            VectorBatchVectorAdd(*inputGateBiasTensor,
+                                 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+        }
         Activation(*inputGateScratchDecoder, *inputGateScratch,
                    TensorInfo({nCell, nBatch}, outputType),
                    ActivationFunction::Sigmoid, 0, 0);
@@ -227,11 +275,30 @@ void RefLstmWorkload::Execute() const
         VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor,
                                                 nCell, *cellStateIn, nBatch, *forgetGateScratch);
     }
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*forgetGateScratchDecoder,
+                                *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
+                                      nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+        VectorBatchVectorAdd(*forgetGateBiasTensor,
+                             nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+    }
     Activation(*forgetGateScratchDecoder, *forgetGateScratch,
                TensorInfo({nCell, nBatch}, outputType),
                ActivationFunction::Sigmoid, 0, 0);
 
     // For each batch and cell: update the cell.
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*cellScratchDecoder,
+                                *cellScratch, nCell, nBatch, m_LayerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
+                                      nCell, *cellScratchDecoder, nBatch, *cellScratch);
+        VectorBatchVectorAdd(*cellBiasTensor,
+                             nCell, *cellScratchDecoder, nBatch, *cellScratch);
+    }
+
     VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
 
     ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
@@ -267,6 +334,15 @@ void RefLstmWorkload::Execute() const
         VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
                                                 nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
     }
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*outputGateScratchDecoder,
+                                *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
+                                      nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+        VectorBatchVectorAdd(*outputGateBiasTensor,
+                             nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+    }
     Activation(*outputGateScratchDecoder, *outputGateScratch,
                TensorInfo({nCell, nBatch}, outputType),
                ActivationFunction::Sigmoid, 0, 0);
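Note on the layer-norm path above: for each gate the diff normalizes the accumulated pre-activations to zero mean and unit variance across the cell dimension, scales them by the per-cell layer-norm weights, and only then adds the gate bias. This is why the scratch buffers are zero-initialized instead of bias-initialized when useLayerNorm is set. The following is a minimal standalone sketch of that per-gate sequence, assuming plain float arrays laid out batch-major as in the workload's scratch buffers; the function name LayerNormGateSketch and the exact epsilon handling are illustrative assumptions, not ArmNN API:

// Illustrative sketch, not ArmNN code: what the MeanStddevNormalization ->
// VectorBatchVectorCwiseProduct -> VectorBatchVectorAdd sequence in the
// hunks above computes for one gate.
#include <cmath>
#include <cstdint>
#include <vector>

void LayerNormGateSketch(std::vector<float>& gate,        // nBatch * nCell pre-activations, batch-major
                         const std::vector<float>& gamma, // nCell layer-norm weights
                         const std::vector<float>& bias,  // nCell gate bias
                         uint32_t nBatch, uint32_t nCell, float epsilon)
{
    for (uint32_t b = 0; b < nBatch; ++b)
    {
        float* row = gate.data() + b * nCell;

        // MeanStddevNormalization: zero mean, unit variance across the
        // cells of one batch entry.
        float mean = 0.0f;
        for (uint32_t c = 0; c < nCell; ++c) { mean += row[c]; }
        mean /= static_cast<float>(nCell);

        float variance = 0.0f;
        for (uint32_t c = 0; c < nCell; ++c)
        {
            const float d = row[c] - mean;
            variance += d * d;
        }
        variance /= static_cast<float>(nCell);
        const float invStd = 1.0f / std::sqrt(variance + epsilon);

        // Cwise product with the layer-norm weights, then add the gate
        // bias. The bias lands here, after normalization, which is why the
        // layer-norm path zeroes the scratch buffers instead of seeding
        // them with the bias.
        for (uint32_t c = 0; c < nCell; ++c)
        {
            row[c] = (row[c] - mean) * invStd * gamma[c] + bias[c];
        }
    }
}

In the workload itself the same computation is expressed through the Decoder/Encoder helpers so that one code path serves all supported tensor data types; the sigmoid (or tanh, for the cell candidate) activation is then applied to the normalized, scaled, and biased result.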