path: root/src/backends/reference/workloads/RefLstmWorkload.cpp
author      Jan Eilers <jan.eilers@arm.com>    2019-06-26 13:10:09 +0100
committer   Jan Eilers <jan.eilers@arm.com>    2019-07-02 09:59:37 +0000
commit      38e05bd2836b1b65b440330a9c283038ba4192c3 (patch)
tree        c232f71ce6a101c70ed65e046678f7b22593dbe4 /src/backends/reference/workloads/RefLstmWorkload.cpp
parent      d0c0cc3e27f1ada9df167d3b9ff248be432d16e1 (diff)
download    armnn-38e05bd2836b1b65b440330a9c283038ba4192c3.tar.gz
IVGCVSW-3236 Extend Ref LSTM with layer normalization support
* Add descriptor values
* Update lstm queue descriptor validate function
* Update lstm workload
* Update isLstmSupported (Cl and Ref), LayerSupportBase, ILayerSupport
* Update lstm layer
* Add unit tests

Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18
Diffstat (limited to 'src/backends/reference/workloads/RefLstmWorkload.cpp')
-rw-r--r--  src/backends/reference/workloads/RefLstmWorkload.cpp  100
1 file changed, 88 insertions(+), 12 deletions(-)
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp
index f8ebc58f6e..70b3443d88 100644
--- a/src/backends/reference/workloads/RefLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefLstmWorkload.cpp
@@ -32,6 +32,10 @@ RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const Wo
, m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
, m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
, m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
+ , m_InputLayerNormWeights (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
+ , m_ForgetLayerNormWeights (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
+ , m_CellLayerNormWeights (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
+ , m_OutputLayerNormWeights (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
{}
void RefLstmWorkload::Execute() const
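The four handles initialised above need matching members in RefLstmWorkload.hpp, which sits outside this diff. A minimal sketch of those declarations, assuming they follow the std::unique_ptr<ScopedCpuTensorHandle> pattern of the existing weight members; the epsilon constant consumed by the MeanStddevNormalization calls later in Execute() is likewise an assumption, not shown by this commit:

```cpp
// Sketch of the matching header additions (RefLstmWorkload.hpp).
// Names mirror the constructor initialisers above; the epsilon value
// is an assumed default, not taken from this diff.
std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeights;
std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeights;
std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeights;
std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeights;

float m_LayerNormEpsilon = 1e-8f; // divisor guard for the normalization
```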
@@ -62,8 +66,9 @@ void RefLstmWorkload::Execute() const
const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
- const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
- const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
+ const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
+ const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
+ const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
// Index the scratch buffers pointers to the global scratch buffer.
std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
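The three flags read at the top of this hunk come from the LstmDescriptor carried in m_Data.m_Parameters; m_LayerNormEnabled is the field this commit introduces alongside the existing CIFG and peephole switches. A sketch of the relevant descriptor fields, with the defaults shown being assumptions:

```cpp
// Relevant LstmDescriptor fields (sketch; defaults are assumed).
struct LstmDescriptorSketch
{
    bool m_CifgEnabled      = true;  // coupled input/forget gate: no separate input gate
    bool m_PeepholeEnabled  = false; // peephole connections from the cell state
    bool m_LayerNormEnabled = false; // per-gate layer normalization (this commit)
};
```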
@@ -134,6 +139,26 @@ void RefLstmWorkload::Execute() const
std::unique_ptr<Decoder<float>> projectionWeightsTensor;
std::unique_ptr<Decoder<float>> projectionBiasTensor;
+ std::unique_ptr<Decoder<float>> inputLayerNormWeights;
+ std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
+ std::unique_ptr<Decoder<float>> cellLayerNormWeights;
+ std::unique_ptr<Decoder<float>> outputLayerNormWeights;
+
+ if (useLayerNorm)
+ {
+ if (!useCifg)
+ {
+ inputLayerNormWeights = MakeDecoder<float>(
+ m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetTensor<void>());
+ }
+ forgetLayerNormWeights = MakeDecoder<float>(
+ m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetTensor<void>());
+ cellLayerNormWeights = MakeDecoder<float>(
+ m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetTensor<void>());
+ outputLayerNormWeights = MakeDecoder<float>(
+ m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetTensor<void>());
+ }
+
if (!useCifg)
{
inputToInputWeightsTensor = MakeDecoder<float>(
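With layer normalization enabled, each gate's accumulated pre-activation a is normalized across the nCell elements of a batch row before the per-gate weight vector γ (the decoders created above) and the gate bias b are applied. A compact statement of what the later per-gate hunks implement, assuming MeanStddevNormalization performs per-row zero-mean, unit-variance normalization with m_LayerNormEpsilon as the stabiliser:

```latex
\mu = \tfrac{1}{n_{\mathrm{cell}}}\sum_{j} a_j ,\qquad
\sigma^2 = \tfrac{1}{n_{\mathrm{cell}}}\sum_{j}\left(a_j-\mu\right)^2 ,\qquad
\hat{a}_j = \frac{a_j-\mu}{\sqrt{\sigma^2+\epsilon}} ,\qquad
g_j = \mathrm{act}\!\left(\gamma_j\,\hat{a}_j + b_j\right)
```

With CIFG there is no input gate, which is why the input layer-norm weights are only decoded when CIFG is off.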
@@ -169,18 +194,32 @@ void RefLstmWorkload::Execute() const
}
}
- // Initialize scratch buffers with bias.
- if (!useCifg)
+ if (!useLayerNorm)
{
- VectorBatchVectorAssign(*inputGateBiasTensor,
- nCell, nBatch, *inputGateScratch);
+ // Initialize scratch buffers with bias.
+ if (!useCifg)
+ {
+ VectorBatchVectorAssign(*inputGateBiasTensor,
+ nCell, nBatch, *inputGateScratch);
+ }
+ VectorBatchVectorAssign(*forgetGateBiasTensor,
+ nCell, nBatch, *forgetGateScratch);
+ VectorBatchVectorAssign(*cellBiasTensor,
+ nCell, nBatch, *cellScratch);
+ VectorBatchVectorAssign(*outputGateBiasTensor,
+ nCell, nBatch, *outputGateScratch);
+ }
+ else
+ {
+ // Initialize scratch buffers with zeroes.
+ if (!useCifg)
+ {
+ ZeroVector(*inputGateScratch, nCell * nBatch);
+ }
+ ZeroVector(*forgetGateScratch, nCell * nBatch);
+ ZeroVector(*cellScratch , nCell * nBatch);
+ ZeroVector(*outputGateScratch, nCell * nBatch);
}
- VectorBatchVectorAssign(*forgetGateBiasTensor,
- nCell, nBatch, *forgetGateScratch);
- VectorBatchVectorAssign(*cellBiasTensor,
- nCell, nBatch, *cellScratch);
- VectorBatchVectorAssign(*outputGateBiasTensor,
- nCell, nBatch, *outputGateScratch);
// For each batch and cell: compute input_weight * input.
if (!useCifg)
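The split exists because, in the layer-norm path, each gate's bias is added after MeanStddevNormalization (see the per-gate hunks below); normalizing a scratch buffer pre-seeded with the bias would rescale the bias along with the accumulated products, so the buffers must start at zero. A plain-loop sketch of the assumed semantics of the two initializers over the nBatch x nCell scratch area:

```cpp
#include <cstdint>

// Assumed semantics (sketch, not the Arm NN source): clear a scratch area.
void ZeroVectorSketch(float* v, uint32_t size)
{
    for (uint32_t i = 0; i < size; ++i) { v[i] = 0.0f; }
}

// Assumed semantics (sketch): broadcast the nCell-element bias into every row.
void VectorBatchVectorAssignSketch(const float* bias, uint32_t nCell,
                                   uint32_t nBatch, float* scratch)
{
    for (uint32_t b = 0; b < nBatch; ++b)
    {
        for (uint32_t c = 0; c < nCell; ++c)
        {
            scratch[b * nCell + c] = bias[c];
        }
    }
}
```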
@@ -216,6 +255,15 @@ void RefLstmWorkload::Execute() const
VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
nCell, *cellStateIn, nBatch, *inputGateScratch);
}
+ if (useLayerNorm)
+ {
+ MeanStddevNormalization(*inputGateScratchDecoder,
+ *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+ VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
+ nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+ VectorBatchVectorAdd(*inputGateBiasTensor,
+ nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+ }
Activation(*inputGateScratchDecoder, *inputGateScratch,
TensorInfo({nCell, nBatch}, outputType),
ActivationFunction::Sigmoid, 0, 0);
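The three layer-norm calls map directly onto the formula given earlier: normalize the accumulated pre-activation, scale by the gate's layer-norm weights, then add the gate bias, all ahead of the sigmoid. A minimal plain-array sketch of the normalization step alone, assuming MeanStddevNormalization has these per-row semantics:

```cpp
#include <cmath>
#include <cstdint>

// Sketch of the assumed MeanStddevNormalization behaviour: normalize each
// batch row of an nBatch x nCell buffer to zero mean and unit std-dev.
void MeanStddevNormalizationSketch(float* data, uint32_t nCell,
                                   uint32_t nBatch, float epsilon)
{
    for (uint32_t b = 0; b < nBatch; ++b)
    {
        float* row = data + b * nCell;

        float mean = 0.0f;
        for (uint32_t c = 0; c < nCell; ++c) { mean += row[c]; }
        mean /= static_cast<float>(nCell);

        float variance = 0.0f;
        for (uint32_t c = 0; c < nCell; ++c)
        {
            const float d = row[c] - mean;
            variance += d * d;
        }
        variance /= static_cast<float>(nCell);

        const float invStdDev = 1.0f / std::sqrt(variance + epsilon);
        for (uint32_t c = 0; c < nCell; ++c)
        {
            row[c] = (row[c] - mean) * invStdDev;
        }
    }
}
```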
@@ -227,11 +275,30 @@ void RefLstmWorkload::Execute() const
VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
*cellStateIn, nBatch, *forgetGateScratch);
}
+ if (useLayerNorm)
+ {
+ MeanStddevNormalization(*forgetGateScratchDecoder,
+ *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+ VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
+ nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+ VectorBatchVectorAdd(*forgetGateBiasTensor,
+ nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+ }
Activation(*forgetGateScratchDecoder, *forgetGateScratch,
TensorInfo({nCell, nBatch}, outputType),
ActivationFunction::Sigmoid, 0, 0);
// For each batch and cell: update the cell.
+ if (useLayerNorm)
+ {
+ MeanStddevNormalization(*cellScratchDecoder,
+ *cellScratch, nCell, nBatch, m_LayerNormEpsilon);
+ VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
+ nCell, *cellScratchDecoder, nBatch, *cellScratch);
+ VectorBatchVectorAdd(*cellBiasTensor,
+ nCell, *cellScratchDecoder, nBatch, *cellScratch);
+ }
+
VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
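From here the cell state update is the standard LSTM recurrence; the VectorVectorCwiseProduct call above computes its first term. With f_t and i_t the gate activations and g_t the activated cell candidate (and, under CIFG, i_t = 1 - f_t):

```latex
c_t = f_t \odot c_{t-1} + i_t \odot g_t
```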
@@ -267,6 +334,15 @@ void RefLstmWorkload::Execute() const
VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
}
+ if (useLayerNorm)
+ {
+ MeanStddevNormalization(*outputGateScratchDecoder,
+ *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+ VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
+ nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+ VectorBatchVectorAdd(*outputGateBiasTensor,
+ nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+ }
Activation(*outputGateScratchDecoder, *outputGateScratch,
TensorInfo({nCell, nBatch}, outputType),
ActivationFunction::Sigmoid, 0, 0);
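Finally the output gate modulates the activated cell state to produce the hidden state; when projection is enabled, the projection weights decoded earlier map it down to nOutput. This is the standard formulation, with clipping omitted here:

```latex
h_t = o_t \odot \mathrm{act}\!\left(c_t\right) ,\qquad
\mathrm{output}_t = W_{\mathrm{proj}}\, h_t + b_{\mathrm{proj}} \quad \text{(projection enabled)}
```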