author:    James Conroy <james.conroy@arm.com>  2020-04-29 20:01:10 +0100
committer: James Conroy <james.conroy@arm.com>  2020-05-02 16:44:33 +0000
commit:    4f1f899da140bb0490cf7e404daeaf1206f4db8b (patch)
tree:      dc6d1215440e0efa677d47a4b944882d72e12cc9 /src/backends/reference/workloads
parent:    56e1a5f68213c9134826ad14c6e1fb4c0d41fb46 (diff)
IVGCVSW-4449 Add QLstm ref implementation
* Adds ref implementation for new HAL 1.3 operator, QLstm.
* Adds Layer and CreateWorkload unit tests.
* Adds WorkloadData validation for QLstm.
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I8a721f07ff06105e6495a1a0561b9503aa8146dc
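
For orientation, the arithmetic this workload implements is the standard LSTM cell recurrence; the peephole, layer-norm and projection terms are optional extras handled in the code but omitted here:

$$
\begin{aligned}
i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) \\
g_t &= \tanh(W_{xg} x_t + W_{hg} h_{t-1} + b_g) \\
o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

With CIFG enabled the input gate is not computed separately; the code derives it as $i_t = 1 - f_t$ (see the `Sub1Vector` call in the diff below).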
Diffstat (limited to 'src/backends/reference/workloads')
 src/backends/reference/workloads/CMakeLists.txt        |   2 +
 src/backends/reference/workloads/RefQLstmWorkload.cpp  | 519 +++++++++++++++++
 src/backends/reference/workloads/RefQLstmWorkload.hpp  |  54 ++
 src/backends/reference/workloads/RefWorkloads.hpp      |   1 +
4 files changed, 576 insertions, 0 deletions
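
The implementation below never manipulates the quantized integers directly: every buffer is read through a `Decoder<float>` and written through an `Encoder<float>`, so the gate arithmetic itself runs in float. A minimal stand-alone sketch of that pattern follows; `SimpleS16Decoder`, `SimpleS16Encoder` and `AddInFloat` are illustrative stand-ins, not armnn's actual `Decoder`/`Encoder` classes or `LstmUtils` helpers:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative stand-in for armnn's Decoder<float> over a QSymmS16 buffer:
// dequantizes on every read.
struct SimpleS16Decoder
{
    const int16_t* data;
    float scale;
    float Get(std::size_t i) const { return static_cast<float>(data[i]) * scale; }
};

// Illustrative stand-in for armnn's Encoder<float>: quantizes, with
// saturation to the int16 range, on every write.
struct SimpleS16Encoder
{
    int16_t* data;
    float scale;
    void Set(std::size_t i, float value)
    {
        long q = std::lround(value / scale);
        q = std::max<long>(-32768, std::min<long>(32767, q));
        data[i] = static_cast<int16_t>(q);
    }
};

// Float-domain elementwise add, the same shape as the LstmUtils helpers the
// workload calls: decode operands -> float arithmetic -> re-encode result.
void AddInFloat(const SimpleS16Decoder& a, const SimpleS16Decoder& b,
                SimpleS16Encoder& out, std::size_t n)
{
    for (std::size_t i = 0; i < n; ++i)
    {
        out.Set(i, a.Get(i) + b.Get(i));
    }
}
```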
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 9f3880e077..1abdb0bd82 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -123,6 +123,8 @@ list(APPEND armnnRefBackendWorkloads_sources
     RefPreluWorkload.hpp
     RefQuantizeWorkload.cpp
     RefQuantizeWorkload.hpp
+    RefQLstmWorkload.cpp
+    RefQLstmWorkload.hpp
     RefReshapeWorkload.cpp
     RefReshapeWorkload.hpp
     RefResizeBilinearWorkload.cpp
diff --git a/src/backends/reference/workloads/RefQLstmWorkload.cpp b/src/backends/reference/workloads/RefQLstmWorkload.cpp
new file mode 100644
index 0000000000..34d048b0cb
--- /dev/null
+++ b/src/backends/reference/workloads/RefQLstmWorkload.cpp
@@ -0,0 +1,519 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefQLstmWorkload.hpp"
+#include "Activation.hpp"
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+#include "LstmUtils.hpp"
+#include "RefWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
+    : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
+    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
+    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
+    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
+    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
+
+    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
+    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
+    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
+    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
+
+    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
+    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
+    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
+
+    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
+    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
+    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
+    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
+
+    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
+    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
+
+    , m_InputLayerNormWeightsTensor   (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
+    , m_ForgetLayerNormWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
+    , m_CellLayerNormWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
+    , m_OutputLayerNormWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
+{}
+
+void RefQLstmWorkload::Execute() const
+{
+    // This is a port of the QLSTM::Execute() method in the Android code base.
+    // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
+    // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
+    // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
+    const DataType& internalType = armnn::DataType::QSymmS16;
+
+    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputStateInInfo = GetTensorInfo(m_Data.m_Inputs[1]);
+    const TensorInfo& cellStateInInfo = GetTensorInfo(m_Data.m_Inputs[2]);
+
+    const TensorInfo& outputStateOutInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+    const TensorInfo& cellStateOutInfo = GetTensorInfo(m_Data.m_Outputs[1]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[2]);
+
+    const TensorShape& inputShape = inputInfo.GetShape();
+    const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
+    const TensorShape& cellStateInShape = cellStateInInfo.GetShape();
+
+    // Infer numBatches, inputSize, outputSize and numUnits
+    const uint32_t numBatches = inputShape[0];
+    const uint32_t inputSize  = inputShape[1];
+    const uint32_t outputSize = outputStateInShape[1];
+    const uint32_t numUnits   = cellStateInShape[1];
+
+    // Optional param settings
+    const bool cifgEnabled       = m_Data.m_Parameters.m_CifgEnabled;
+    const bool peepholeEnabled   = m_Data.m_Parameters.m_PeepholeEnabled;
+    const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
+    const bool layerNormEnabled  = m_Data.m_Parameters.m_LayerNormEnabled;
+
+    // Input decoders
+    std::unique_ptr<Decoder<float>> inputDecoder =
+            MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+    std::unique_ptr<Decoder<float>> outputStateInDecoder =
+            MakeDecoder<float>(outputStateInInfo, m_Data.m_Inputs[1]->Map());
+    std::unique_ptr<Decoder<float>> cellStateInDecoder =
+            MakeDecoder<float>(cellStateInInfo, m_Data.m_Inputs[2]->Map());
+
+    // Output decoders
+    std::unique_ptr<Decoder<float>> outputStateOutDecoder =
+            MakeDecoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Decoder<float>> cellStateOutDecoder =
+            MakeDecoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
+    std::unique_ptr<Decoder<float>> outputDecoder =
+            MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+
+    // Output encoders
+    std::unique_ptr<Encoder<float>> outputStateOutEncoder =
+            MakeEncoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Encoder<float>> cellStateOutEncoder =
+            MakeEncoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
+    std::unique_ptr<Encoder<float>> outputEncoder =
+            MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+
+    // Weights decoders
+    std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
+            m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
+            m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
+            m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());
+
+    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
+            m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
+            m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
+            m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());
+
+    // Optional CIFG params
+    std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
+    std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
+    std::unique_ptr<Decoder<float>> inputGateBiasDecoder;
+
+    // Optional Peephole params
+    std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
+    std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
+    std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;
+
+    // Optional Projection params
+    std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
+    std::unique_ptr<Decoder<float>> projectionBiasDecoder;
+
+    // Optional Layer Norm params
+    std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
+    std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
+    std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
+    std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;
+
+    // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
+    std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
+    std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
+    std::unique_ptr<Decoder<float>> outputGateBiasDecoder;
+
+    // Int16 vectors for internal state data (to be decoded/encoded)
+    const uint32_t stateTensorSize = numBatches * numUnits;
+    std::vector<int16_t> inputGateData(stateTensorSize);
+    std::vector<int16_t> cellGateData(stateTensorSize);
+    std::vector<int16_t> forgetGateData(stateTensorSize);
+    std::vector<int16_t> outputGateData(stateTensorSize);
+    std::vector<int8_t> hiddenStateData(stateTensorSize); // Hidden state is QAsymmS8, so it needs an 8-bit buffer
+
+    armnn::TensorInfo inputGateInfo(
+            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
+    armnn::TensorInfo cellGateInfo(
+            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
+    armnn::TensorInfo forgetGateInfo(
+            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
+    armnn::TensorInfo outputGateInfo(
+            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
+    armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
+                                      armnn::DataType::QAsymmS8,
+                                      m_Data.m_Parameters.m_HiddenStateScale,
+                                      m_Data.m_Parameters.m_HiddenStateZeroPoint);
+
+    // Decoders/Encoders for internal states
+    std::unique_ptr<Decoder<float>> inputGateDecoder =
+            MakeDecoder<float>(inputGateInfo, inputGateData.data());
+    std::unique_ptr<Decoder<float>> cellGateDecoder =
+            MakeDecoder<float>(cellGateInfo, cellGateData.data());
+    std::unique_ptr<Decoder<float>> forgetGateDecoder =
+            MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
+    std::unique_ptr<Decoder<float>> outputGateDecoder =
+            MakeDecoder<float>(outputGateInfo, outputGateData.data());
+    std::unique_ptr<Decoder<float>> hiddenStateDecoder =
+            MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());
+
+    std::unique_ptr<Encoder<float>> inputGateEncoder =
+            MakeEncoder<float>(inputGateInfo, inputGateData.data());
+    std::unique_ptr<Encoder<float>> cellGateEncoder =
+            MakeEncoder<float>(cellGateInfo, cellGateData.data());
+    std::unique_ptr<Encoder<float>> forgetGateEncoder =
+            MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
+    std::unique_ptr<Encoder<float>> outputGateEncoder =
+            MakeEncoder<float>(outputGateInfo, outputGateData.data());
+    std::unique_ptr<Encoder<float>> hiddenStateEncoder =
+            MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());
+
+    // Create decoders for optional params if they are enabled
+    if (!cifgEnabled)
+    {
+        inputToInputWeightsDecoder = MakeDecoder<float>(
+                m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
+        recurrentToInputWeightsDecoder = MakeDecoder<float>(
+                m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
+    }
+
+    if (peepholeEnabled)
+    {
+        if (!cifgEnabled)
+        {
+            cellToInputWeightsDecoder = MakeDecoder<float>(
+                    m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
+        }
+        cellToForgetWeightsDecoder = MakeDecoder<float>(
+                m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
+        cellToOutputWeightsDecoder = MakeDecoder<float>(
+                m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
+    }
+
+    if (projectionEnabled)
+    {
+        projectionWeightsDecoder = MakeDecoder<float>(
+                m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
+        if (m_ProjectionBiasTensor)
+        {
+            projectionBiasDecoder = MakeDecoder<float>(
+                    m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
+        }
+    }
+
+    if (layerNormEnabled)
+    {
+        if (!cifgEnabled)
+        {
+            inputLayerNormWeightsDecoder = MakeDecoder<float>(
+                    m_InputLayerNormWeightsTensor->GetTensorInfo(), m_InputLayerNormWeightsTensor->GetTensor<void>());
+
+            // Bias only used if layer norm enabled
+            armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
+                    m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
+            inputGateBiasDecoder = MakeDecoder<float>(
+                    inputGateBiasTensorInfo, m_InputGateBiasTensor->GetTensor<void>());
+        }
+
+        forgetLayerNormWeightsDecoder = MakeDecoder<float>(
+                m_ForgetLayerNormWeightsTensor->GetTensorInfo(), m_ForgetLayerNormWeightsTensor->GetTensor<void>());
+        cellLayerNormWeightsDecoder = MakeDecoder<float>(
+                m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetTensor<void>());
+        outputLayerNormWeightsDecoder = MakeDecoder<float>(
+                m_OutputLayerNormWeightsTensor->GetTensorInfo(), m_OutputLayerNormWeightsTensor->GetTensor<void>());
+
+        // Bias only used if layer norm enabled
+        armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
+                m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
+        forgetGateBiasDecoder = MakeDecoder<float>(
+                forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetTensor<void>());
+
+        armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
+                m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
+        cellGateBiasDecoder = MakeDecoder<float>(
+                cellGateBiasTensorInfo, m_CellBiasTensor->GetTensor<void>());
+
+        armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
+                m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
+        outputGateBiasDecoder = MakeDecoder<float>(
+                outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetTensor<void>());
+    }
+
+    // Initialize internal state tensors with zeroes.
+    if (!cifgEnabled)
+    {
+        ZeroVector(*inputGateEncoder, stateTensorSize);
+    }
+    ZeroVector(*forgetGateEncoder, stateTensorSize);
+    ZeroVector(*cellGateEncoder, stateTensorSize);
+    ZeroVector(*outputGateEncoder, stateTensorSize);
+    ZeroVector(*hiddenStateEncoder, stateTensorSize);
+
+    // Input weights * Input
+    if (!cifgEnabled)
+    {
+        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
+                numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
+    }
+
+    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
+            numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);
+
+    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
+            numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);
+
+    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
+            numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);
+
+    // Recurrent weights * OutputStateIn
+    if (!cifgEnabled)
+    {
+        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
+                numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
+    }
+
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
+            numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);
+
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
+            numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);
+
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
+            numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);
+
+    // Input gate.
+    if (!cifgEnabled)
+    {
+        if (peepholeEnabled)
+        {
+            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
+                    numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
+        }
+
+        if (layerNormEnabled)
+        {
+            inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
+                                               m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
+                                               1024);
+            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
+
+            MeanStddevNormalization(*inputGateDecoder,
+                                    *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
+
+            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
+
+            VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
+                    numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
+
+            inputGateInfo.SetQuantizationScale(1.f / 4096);
+            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
+
+            VectorBatchVectorAdd(*inputGateBiasDecoder,
+                    numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
+
+            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
+        }
+
+        inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
+        inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
+
+        // Input gate sigmoid
+        Activation(*inputGateDecoder, *inputGateEncoder,
+                   TensorInfo({numUnits, numBatches}, internalType),
+                   ActivationFunction::Sigmoid, 0, 0);
+
+        inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
+    }
+
+    // Forget gate
+    if (peepholeEnabled)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
+                *cellStateInDecoder, numBatches, *forgetGateEncoder);
+    }
+
+    if (layerNormEnabled)
+    {
+        // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
+        forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
+                                            m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
+                                            1024);
+        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
+
+        MeanStddevNormalization(*forgetGateDecoder,
+                                *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
+
+        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
+
+        VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
+                numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
+
+        // Dequantize layer norm output to (1 / 4096)
+        forgetGateInfo.SetQuantizationScale(1.f / 4096);
+        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
+
+        VectorBatchVectorAdd(*forgetGateBiasDecoder,
+                numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
+
+        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
+    }
+
+    forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
+    forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
+
+    // Forget gate sigmoid
+    Activation(*forgetGateDecoder, *forgetGateEncoder,
+               TensorInfo({numUnits, numBatches}, internalType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
+
+    // Cell (Modulation) gate
+    if (layerNormEnabled)
+    {
+        cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
+                                          m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
+                                          1024);
+        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
+
+        MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
+
+        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
+
+        VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
+                numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
+
+        cellGateInfo.SetQuantizationScale(1.f / 4096);
+        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
+
+        VectorBatchVectorAdd(*cellGateBiasDecoder,
+                numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
+
+        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
+    }
+
+    cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
+    cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
+
+    // Cell (Modulation) gate tanH
+    Activation(*cellGateDecoder, *cellGateEncoder,
+               TensorInfo({numUnits, numBatches}, internalType),
+               ActivationFunction::TanH, 1.0f, 1.0f);
+
+    cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
+
+    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);
+
+    if (cifgEnabled)
+    {
+        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
+        VectorVectorCwiseProductAccumulate(
+                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
+    }
+    else
+    {
+        VectorVectorCwiseProductAccumulate(
+                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
+    }
+
+    // Final cell state out calculated here
+    if (m_Data.m_Parameters.m_CellClip > 0.0)
+    {
+        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
+    }
+
+    // Output gate.
+    if (peepholeEnabled)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
+                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
+    }
+
+    if (layerNormEnabled)
+    {
+        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
+                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
+                                            1024);
+        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
+
+        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
+
+        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
+
+        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
+                numBatches, *outputGateEncoder);
+
+        outputGateInfo.SetQuantizationScale(1.f / 4096);
+        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
+
+        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);
+
+        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
+    }
+
+    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
+    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
+
+    // Output gate sigmoid
+    Activation(*outputGateDecoder, *outputGateEncoder,
+               TensorInfo({numUnits, numBatches}, internalType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
+
+    // Hidden state tanH
+    Activation(*cellStateOutDecoder, *cellGateEncoder,
+               TensorInfo({numUnits, numBatches}, internalType),
+               ActivationFunction::TanH, 1.0f, 1.0f);
+
+    // Final hidden state output
+    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);
+
+    // Projection
+    if (m_Data.m_Parameters.m_ProjectionEnabled)
+    {
+        if (m_ProjectionBiasTensor)
+        {
+            VectorBatchVectorAssign(*projectionBiasDecoder,
+                    outputSize, numBatches, *outputEncoder);
+        }
+
+        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder,
+                outputSize, numUnits, *hiddenStateDecoder, numBatches, *outputEncoder);
+
+        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
+        {
+            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
+        }
+    }
+    else
+    {
+        // Output has same quantization scale as hidden state if projection is disabled
+        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
+    }
+
+    // output == outputStateOut
+    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefQLstmWorkload.hpp b/src/backends/reference/workloads/RefQLstmWorkload.hpp
new file mode 100644
index 0000000000..19d3a2af0f
--- /dev/null
+++ b/src/backends/reference/workloads/RefQLstmWorkload.hpp
@@ -0,0 +1,54 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/TypesUtils.hpp>
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefQLstmWorkload : public BaseWorkload<QLstmQueueDescriptor>
+{
+public:
+    explicit RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info);
+
+    virtual void Execute() const override;
+
+private:
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeightsTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeightsTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToOutputWeightsTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBiasTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionBiasTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeightsTensor;
+
+    float m_LayerNormEpsilon = static_cast<float>(1e-8);
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index cbfade3c02..e396a6ba3c 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -48,6 +48,7 @@
 #include "RefPermuteWorkload.hpp"
 #include "RefPadWorkload.hpp"
 #include "RefPreluWorkload.hpp"
+#include "RefQLstmWorkload.hpp"
 #include "RefQuantizeWorkload.hpp"
 #include "RefReshapeWorkload.hpp"
 #include "RefResizeBilinearWorkload.hpp"
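
One idiom worth calling out in RefQLstmWorkload.cpp above: the gate buffers stay int16 throughout, but their quantization scale changes between steps by calling SetQuantizationScale and rebuilding the encoder/decoder pair (input scale * layer-norm weight scale * 1024 before MeanStddevNormalization, 1/4096 after the bias add, then the cell-state scale before the sigmoid). The following stand-alone sketch shows what such a re-encode amounts to numerically; `Requantize` is a hypothetical helper for illustration, not a function from the commit:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the commit): re-express int16 values quantized
// at srcScale as int16 values quantized at dstScale. In the workload this
// happens implicitly whenever a value decoded under one TensorInfo scale is
// written back through an encoder built with a different scale.
std::vector<int16_t> Requantize(const std::vector<int16_t>& src,
                                float srcScale, float dstScale)
{
    std::vector<int16_t> dst(src.size());
    for (std::size_t i = 0; i < src.size(); ++i)
    {
        const float real = static_cast<float>(src[i]) * srcScale;       // decode
        long q = std::lround(real / dstScale);                          // re-encode
        q = std::max<long>(-32768, std::min<long>(32767, q));           // saturate
        dst[i] = static_cast<int16_t>(q);
    }
    return dst;
}
```

Keeping the data in int16 and only swapping the scale keeps the reference backend simple at the cost of extra quantize/dequantize round trips, which is acceptable for a reference implementation that prioritizes clarity over speed.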