From eb2b329b761ce3206505ed8d2eab071a2f97d5e7 Mon Sep 17 00:00:00 2001
From: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Date: Tue, 7 May 2019 12:02:30 +0100
Subject: IVGCVSW-2997 Refactor reference LSTM workload

Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Change-Id: I6883f878d9f701a55153292769d2fc0530d2529e
---
 src/backends/backendsCommon/WorkloadData.cpp       |  12 +
 src/backends/reference/RefLayerSupport.cpp         |  39 ++-
 src/backends/reference/RefWorkloadFactory.cpp      |   2 +-
 src/backends/reference/backend.mk                  |   2 +-
 src/backends/reference/workloads/Activation.cpp    |  20 +-
 src/backends/reference/workloads/Activation.hpp    |   7 -
 src/backends/reference/workloads/BaseIterator.hpp  |  30 +-
 src/backends/reference/workloads/CMakeLists.txt    |   5 +-
 src/backends/reference/workloads/LstmUtils.hpp     | 218 ++++++++++++
 .../reference/workloads/RefLstmFloat32Workload.cpp | 379 ---------------------
 .../reference/workloads/RefLstmFloat32Workload.hpp |  43 ---
 .../reference/workloads/RefLstmWorkload.cpp        | 307 +++++++++++++++++
 .../reference/workloads/RefLstmWorkload.hpp        |  43 +++
 src/backends/reference/workloads/RefWorkloads.hpp  |   2 +-
 14 files changed, 644 insertions(+), 465 deletions(-)
 create mode 100644 src/backends/reference/workloads/LstmUtils.hpp
 delete mode 100644 src/backends/reference/workloads/RefLstmFloat32Workload.cpp
 delete mode 100644 src/backends/reference/workloads/RefLstmFloat32Workload.hpp
 create mode 100644 src/backends/reference/workloads/RefLstmWorkload.cpp
 create mode 100644 src/backends/reference/workloads/RefLstmWorkload.hpp

diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index ca9d7d9c5e..61e0d4004d 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -860,6 +860,18 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "LstmQueueDescriptor", 2, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "LstmQueueDescriptor", 2, "output");
+
+    std::vector<DataType> supportedTypes = {
+        DataType::Float32,
+    };
+
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+                      supportedTypes,
+                      "LstmQueueDescriptor");
+
+    ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+                      supportedTypes,
+                      "LstmQueueDescriptor");
 }
 
 void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 1b1f0ce1c6..67c13c3f84 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -594,12 +594,6 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                       const TensorInfo* cellToOutputWeights,
                                       Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(outputStateIn);
-    ignore_unused(cellStateIn);
-    ignore_unused(scratchBuffer);
-    ignore_unused(outputStateOut);
-    ignore_unused(cellStateOut);
-    ignore_unused(output);
     ignore_unused(descriptor);
     ignore_unused(inputToForgetWeights);
     ignore_unused(inputToCellWeights);
@@ -618,10 +612,35 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
     ignore_unused(projectionBias);
     ignore_unused(cellToForgetWeights);
     ignore_unused(cellToOutputWeights);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &FalseFuncU8<>);
+
+    bool supported = true;
+
+    std::array<DataType,1> supportedTypes = {
+        DataType::Float32
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference Lstm: input is not a supported type.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
+                                  "Reference Lstm: input and outputStateIn types are mismatched");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
+                                  "Reference Lstm: input and cellStateIn types are mismatched");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
+                                  "Reference Lstm: input and scratchBuffer types are mismatched");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
+                                  "Reference Lstm: input and outputStateOut types are mismatched");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
+                                  "Reference Lstm: input and cellStateOut types are mismatched");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference Lstm: input and output types are mismatched");
+
+    return supported;
 }
 
 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
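The rule helpers used above (CheckSupportRule, TypeAnyOf, TypesAreEqual) replace the old TrueFunc/FalseFuncU8 dispatch: each rule object evaluates to a bool, and CheckSupportRule appends the message to reasonIfUnsupported when the rule fails. A minimal self-contained sketch of the pattern, using simplified stand-in types rather than the real armnn definitions:

#include <array>
#include <cstddef>
#include <iostream>
#include <string>

// Illustration-only stand-ins; the real armnn types are richer.
enum class DataType { Float32, QuantisedAsymm8 };
struct TensorInfo
{
    DataType m_Type;
    DataType GetDataType() const { return m_Type; }
};

// A rule is any object exposing operator()() -> bool.
struct TypeAnyOf
{
    template <std::size_t N>
    TypeAnyOf(const TensorInfo& info, const std::array<DataType, N>& types)
        : m_Res(false)
    {
        for (DataType t : types) { m_Res |= (t == info.GetDataType()); }
    }
    bool operator()() const { return m_Res; }
    bool m_Res;
};

struct TypesAreEqual
{
    TypesAreEqual(const TensorInfo& a, const TensorInfo& b)
        : m_Res(a.GetDataType() == b.GetDataType()) {}
    bool operator()() const { return m_Res; }
    bool m_Res;
};

template <typename F>
bool CheckSupportRule(F rule, std::string* reasonIfUnsupported, const char* reason)
{
    bool supported = rule();
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported += std::string(reason) + "\n";
    }
    return supported;
}

int main()
{
    std::array<DataType, 1> supportedTypes = { DataType::Float32 };
    TensorInfo input{ DataType::Float32 };
    TensorInfo outputStateIn{ DataType::QuantisedAsymm8 };

    std::string reason;
    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), &reason,
                                  "input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), &reason,
                                  "input and outputStateIn types are mismatched");
    std::cout << (supported ? "supported" : reason) << std::endl;
}

The `supported &=` chaining is deliberate: every rule still runs after one fails, so the caller sees all mismatched tensors in a single reason string rather than just the first.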
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 8887bb719a..6603aaf27b 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -274,7 +274,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescr
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
 {
-    return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
+    return std::make_unique<RefLstmWorkload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 06459ed31b..5034c0fe9e 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -47,7 +47,7 @@ BACKEND_SOURCES := \
         workloads/RefFullyConnectedUint8Workload.cpp \
         workloads/RefGatherWorkload.cpp \
         workloads/RefL2NormalizationFloat32Workload.cpp \
-        workloads/RefLstmFloat32Workload.cpp \
+        workloads/RefLstmWorkload.cpp \
        workloads/RefMeanFloat32Workload.cpp \
         workloads/RefMeanUint8Workload.cpp \
         workloads/RefMergerFloat32Workload.cpp \
diff --git a/src/backends/reference/workloads/Activation.cpp b/src/backends/reference/workloads/Activation.cpp
index 760c9a0ccd..2b0c84e226 100644
--- a/src/backends/reference/workloads/Activation.cpp
+++ b/src/backends/reference/workloads/Activation.cpp
@@ -88,26 +88,16 @@ void Activation(Decoder<float>& in,
                 float a,
                 float b)
 {
-    for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
+    unsigned int numElements = tensorInfo.GetNumElements();
+
+    for (unsigned int i = 0; i < numElements; i++)
     {
         out.Set(Activation(in.Get(), function, a, b));
         ++in;
         ++out;
     }
+    in -= numElements;
+    out -= numElements;
 }
-
-// This is still used by Reference LSTM implementation
-void Activation(const float* in,
-                float* out,
-                const TensorInfo& tensorInfo,
-                ActivationFunction function,
-                float a,
-                float b)
-{
-    for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
-    {
-        out[i] = Activation(in[i], function, a, b);
-    }
-}
 
 } //namespace armnn
diff --git a/src/backends/reference/workloads/Activation.hpp b/src/backends/reference/workloads/Activation.hpp
--- a/src/backends/reference/workloads/Activation.hpp
+++ b/src/backends/reference/workloads/Activation.hpp
@@ -29,13 +29,6 @@ void Activation(Decoder<float>& in,
                 float a,
                 float b);
 
-// This is still used by Reference LSTM implementation
-void Activation(const float* in,
-                float* out,
-                const TensorInfo& tensorInfo,
-                ActivationFunction function,
-                float a,
-                float b);
 
 } //namespace armnn
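With the raw float* overload gone, the Decoder/Encoder overload of Activation is the only entry point, and it rewinds its iterators before returning so the LSTM workload can apply it repeatedly over the same gate buffers. A toy illustration of that rewind convention using a plain pointer wrapper (a stand-in type, not the armnn iterator classes):

#include <cassert>
#include <vector>

// Minimal stand-in for armnn's TypedIterator: wraps a raw pointer.
struct FloatIterator
{
    float* m_It;
    FloatIterator& operator++()           { ++m_It; return *this; }
    FloatIterator& operator+=(unsigned n) { m_It += n; return *this; }
    FloatIterator& operator-=(unsigned n) { m_It -= n; return *this; }
    float Get() const  { return *m_It; }
    void  Set(float v) { *m_It = v; }
};

// Scales a vector in place and rewinds, mirroring the convention
// followed by Activation() and the LstmUtils helpers below.
void ScaleVector(FloatIterator& io, unsigned numElements, float s)
{
    for (unsigned i = 0; i < numElements; ++i)
    {
        io.Set(io.Get() * s);
        ++io;
    }
    io -= numElements; // rewind: caller sees the iterator where it started
}

int main()
{
    std::vector<float> data = {1.f, 2.f, 3.f};
    FloatIterator it{data.data()};
    ScaleVector(it, 3, 2.f);  // safe to call again:
    ScaleVector(it, 3, 0.5f); // the iterator was rewound in between
    assert(data[1] == 2.f);
}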
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 3439e41b48..97af95a0eb 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -29,8 +29,6 @@ template<typename IType>
 class Decoder : public BaseIterator
 {
 public:
-    using InterfaceType = IType;
-
     Decoder() {}
 
     virtual ~Decoder() {}
@@ -42,13 +40,13 @@ template<typename IType>
 class Encoder : public BaseIterator
 {
 public:
-    using InterfaceType = IType;
-
     Encoder() {}
 
     virtual ~Encoder() {}
 
     virtual void Set(IType right) = 0;
+
+    virtual IType Get() const = 0;
 };
 
 template<typename T, typename Base>
@@ -77,6 +75,7 @@ public:
         return *this;
     }
 
+protected:
     T* m_Iterator;
 };
 
@@ -135,6 +134,11 @@ public:
         *m_Iterator = armnn::Quantize(right, m_Scale, m_Offset);
     }
 
+    float Get() const override
+    {
+        return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+    }
+
 private:
     const float m_Scale;
     const int32_t m_Offset;
@@ -151,6 +155,11 @@ public:
         *m_Iterator = armnn::Quantize(right, m_Scale, m_Offset);
     }
 
+    float Get() const override
+    {
+        return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+    }
+
 private:
     const float m_Scale;
     const int32_t m_Offset;
@@ -166,6 +175,11 @@ public:
     {
         *m_Iterator = right;
     }
+
+    float Get() const override
+    {
+        return *m_Iterator;
+    }
 };
 
 class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
@@ -178,7 +192,11 @@ public:
     {
         *m_Iterator = right;
     }
-};
 
+    bool Get() const override
+    {
+        return *m_Iterator;
+    }
+};
 
-} //namespace armnn
\ No newline at end of file
+} //namespace armnn
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 596c0993e0..b1cdef9cf1 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -26,6 +26,7 @@ list(APPEND armnnRefBackendWorkloads_sources
     FullyConnected.hpp
     Gather.cpp
     Gather.hpp
+    LstmUtils.hpp
     Maximum.hpp
     Merger.hpp
     Merger.cpp
@@ -80,8 +81,8 @@ list(APPEND armnnRefBackendWorkloads_sources
     RefGatherWorkload.hpp
     RefL2NormalizationFloat32Workload.cpp
     RefL2NormalizationFloat32Workload.hpp
-    RefLstmFloat32Workload.cpp
-    RefLstmFloat32Workload.hpp
+    RefLstmWorkload.cpp
+    RefLstmWorkload.hpp
     RefMergerFloat32Workload.cpp
     RefMergerFloat32Workload.hpp
     RefMergerUint8Workload.cpp
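The new pure-virtual Get() on Encoder is what allows read-modify-write through an encoder alone, e.g. outResult.Set(outResult.Get() + x) in the LSTM helpers; for quantized encoders, Get() dequantizes so the accumulation happens in float space. A compact sketch of that idea with a hypothetical stand-in encoder and simplified quantization math:

#include <cassert>
#include <cmath>
#include <cstdint>

// Stand-in for a quantized encoder: Set() quantizes, Get() dequantizes,
// so "encoder.Set(encoder.Get() + x)" accumulates in float space.
struct QAsymm8Encoder
{
    uint8_t* m_It;
    float    m_Scale;
    int32_t  m_Offset;

    void Set(float v)
    {
        int32_t q = static_cast<int32_t>(std::round(v / m_Scale)) + m_Offset;
        q = q < 0 ? 0 : (q > 255 ? 255 : q); // clamp to the uint8 range
        *m_It = static_cast<uint8_t>(q);
    }
    float Get() const
    {
        return m_Scale * static_cast<float>(*m_It - m_Offset);
    }
};

int main()
{
    uint8_t storage = 0;
    QAsymm8Encoder enc{&storage, /*scale*/ 0.5f, /*offset*/ 10};
    enc.Set(2.0f);             // stored as 14
    enc.Set(enc.Get() + 1.0f); // read-modify-write: now 3.0f -> 16
    assert(storage == 16);
}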
diff --git a/src/backends/reference/workloads/LstmUtils.hpp b/src/backends/reference/workloads/LstmUtils.hpp
new file mode 100644
index 0000000000..db02a84a45
--- /dev/null
+++ b/src/backends/reference/workloads/LstmUtils.hpp
@@ -0,0 +1,218 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+#include <backendsCommon/CpuTensorHandle.hpp>
+
+namespace
+{
+
+// Helper functions ported from the Android code base
+// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+
+void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder<float>& matrix,
+                                         uint32_t mRows,
+                                         uint32_t mCols,
+                                         armnn::Decoder<float>& vector,
+                                         uint32_t nBatch,
+                                         armnn::Encoder<float>& outResult)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t r = 0; r < mRows; r++)
+        {
+            vector += b * mCols;
+            for (uint32_t c = 0; c < mCols; c++)
+            {
+                outResult.Set(outResult.Get() + matrix.Get() * vector.Get());
+                ++matrix;
+                ++vector;
+            }
+            outResult += 1;
+            vector -= (b+1) * mCols;
+        }
+        matrix -= (mRows * mCols);
+    }
+    outResult -= (mRows * nBatch);
+}
+
+void VectorBatchVectorAssign(armnn::Decoder<float>& vector,
+                             uint32_t vSize,
+                             uint32_t nBatch,
+                             armnn::Encoder<float>& outBatchVector)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t v = 0; v < vSize; v++)
+        {
+            outBatchVector.Set(vector.Get());
+            ++outBatchVector;
+            ++vector;
+        }
+        vector -= vSize;
+    }
+    outBatchVector -= (nBatch * vSize);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder<float>& vector,
+                                             uint32_t vSize,
+                                             armnn::Decoder<float>& batchVector,
+                                             uint32_t nBatch,
+                                             armnn::Encoder<float>& outResult)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t v = 0; v < vSize; v++)
+        {
+            outResult.Set(outResult.Get() + vector.Get() * batchVector.Get());
+            ++outResult;
+            ++vector;
+            ++batchVector;
+        }
+        vector -= vSize;
+    }
+    batchVector -= vSize * nBatch;
+    outResult -= vSize * nBatch;
+}
+
+void Sub1Vector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                armnn::Encoder<float>& result)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        result.Set(1.0f - vector.Get());
+        ++vector;
+        ++result;
+    }
+    vector -= vSize;
+    result -= vSize;
+}
+
+void VectorVectorCwiseProduct(armnn::Decoder<float>& vector1,
+                              armnn::Decoder<float>& vector2,
+                              uint32_t vSize,
+                              armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(vector1.Get() * vector2.Get());
+        ++outResult;
+        ++vector1;
+        ++vector2;
+    }
+    outResult -= vSize;
+    vector1 -= vSize;
+    vector2 -= vSize;
+}
+
+void VectorVectorCwiseProductAccumulate(armnn::Decoder<float>& vector1,
+                                        armnn::Decoder<float>& vector2,
+                                        uint32_t vSize,
+                                        armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(outResult.Get() + vector1.Get() * vector2.Get());
+        ++outResult;
+        ++vector1;
+        ++vector2;
+    }
+    outResult -= vSize;
+    vector1 -= vSize;
+    vector2 -= vSize;
+}
+
+float Clip(float f,
+           float absLimit)
+{
+    float result = (absLimit < f) ? absLimit : f;
+    result = (-absLimit > result) ? -absLimit : result;
+    return result;
+}
+
+void ClipVector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                float absLimit,
+                armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(Clip(vector.Get(), absLimit));
+        ++vector;
+        ++outResult;
+    }
+    vector -= vSize;
+    outResult -= vSize;
+}
+
+void CopyVector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(vector.Get());
+        ++outResult;
+        ++vector;
+    }
+    outResult -= vSize;
+    vector -= vSize;
+}
+
+void SetActivationParameters(uint32_t activation,
+                             armnn::ActivationFunction& outArmnnActivation,
+                             float& outA,
+                             float& outB)
+{
+    switch (activation)
+    {
+        case 0: // None
+            outA = 0;
+            outB = 0;
+            return;
+
+        case 1: // Relu
+            outArmnnActivation = armnn::ActivationFunction::ReLu;
+            outA = 0;
+            outB = 0;
+            return;
+
+        case 3: // Relu6
+            outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
+            outA = 6;
+            outB = 0;
+            return;
+
+        case 4: // Tanh
+            outArmnnActivation = armnn::ActivationFunction::TanH;
+            outA = 1;
+            outB = 1;
+            return;
+
+        case 6: // Sigmoid
+            outArmnnActivation = armnn::ActivationFunction::Sigmoid;
+            outA = 0;
+            outB = 0;
+            return;
+
+        default:
+            throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
+    }
+}
+
+std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
+{
+    if (!ptr)
+    {
+        return nullptr;
+    }
+
+    return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
+}
+
+} // anonymous namespace
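The contract of MatrixBatchVectorMultiplyAccumulate above is result[b][r] += sum over c of matrix[r][c] * vector[b][c]; per gate, the LSTM first loads the bias into scratch with VectorBatchVectorAssign and then accumulates weights times input on top. A plain-array mirror of that contract with hypothetical sizes (the iterator version performs the same walk and then rewinds everything):

#include <cassert>

// Plain-array mirror of MatrixBatchVectorMultiplyAccumulate's contract:
// outResult[b][r] += sum_c matrix[r][c] * vector[b][c].
void MatVecAccumulate(const float* matrix, unsigned mRows, unsigned mCols,
                      const float* vector, unsigned nBatch, float* outResult)
{
    for (unsigned b = 0; b < nBatch; ++b)
        for (unsigned r = 0; r < mRows; ++r)
            for (unsigned c = 0; c < mCols; ++c)
                outResult[b * mRows + r] += matrix[r * mCols + c] * vector[b * mCols + c];
}

int main()
{
    // One batch, 2x2 weights, as used per gate: scratch = bias, then += W * x.
    const float weights[4] = {1, 2, 3, 4}; // row-major 2x2
    const float input[2]   = {1, 1};
    float scratch[2]       = {0.5f, 0.5f}; // pre-loaded with the gate bias
    MatVecAccumulate(weights, 2, 2, input, 1, scratch);
    assert(scratch[0] == 3.5f && scratch[1] == 7.5f);
}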
diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.cpp b/src/backends/reference/workloads/RefLstmFloat32Workload.cpp
deleted file mode 100644
index c697b66658..0000000000
--- a/src/backends/reference/workloads/RefLstmFloat32Workload.cpp
+++ /dev/null
@@ -1,379 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefLstmFloat32Workload.hpp"
-#include "RefWorkloadUtils.hpp"
-#include "Activation.hpp"
-
-namespace
-{
-
-// Helper functions ported from the Android code base
-// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
-
-void MatrixBatchVectorMultiplyAccumulate(const float* matrix,
-                                         uint32_t mRows,
-                                         uint32_t mCols,
-                                         const float* vector,
-                                         uint32_t nBatch,
-                                         float* outResult,
-                                         int resultStride = 1)
-{
-    float* resultInBatch = outResult;
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        const float* matrixPtr = matrix;
-        for (uint32_t r = 0; r < mRows; r++)
-        {
-            const float* vectorInBatch = vector + b * mCols;
-            for (uint32_t c = 0; c < mCols; c++)
-            {
-                *resultInBatch += *matrixPtr++ * *vectorInBatch++;
-            }
-            resultInBatch += resultStride;
-        }
-    }
-}
-
-void VectorBatchVectorAssign(const float* vector,
-                             uint32_t vSize,
-                             uint32_t nBatch,
-                             float* outBatchVector)
-{
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        memcpy(outBatchVector + b * vSize, vector, vSize * sizeof(float));
-    }
-}
-
-void VectorBatchVectorCwiseProductAccumulate(const float* vector,
-                                             uint32_t vSize,
-                                             const float* batchVector,
-                                             uint32_t nBatch,
-                                             float* outResult)
-{
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        for (uint32_t v = 0; v < vSize; v++)
-        {
-            *outResult++ += vector[v] * *batchVector++;
-        }
-    }
-}
-
-void Sub1Vector(const float* vector,
-                uint32_t vSize,
-                float* result)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        *result++ = 1.0f - *vector++;
-    }
-}
-
-void VectorVectorCwiseProduct(const float* vector1,
-                              const float* vector2,
-                              uint32_t vSize,
-                              float* outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        *outResult++ = *vector1++ * *vector2++;
-    }
-}
-
-void VectorVectorCwiseProductAccumulate(const float* vector1,
-                                        const float* vector2,
-                                        uint32_t vSize,
-                                        float* outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        *outResult++ += *vector1++ * *vector2++;
-    }
-}
-
-float Clip(float f,
-           float absLimit)
-{
-    float result = (absLimit < f) ? absLimit : f;
-    result = (-absLimit > result) ? -absLimit : result;
-    return result;
-}
-
-void ClipVector(const float* vector,
-                uint32_t vSize,
-                float absLimit,
-                float* outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        *outResult++ = Clip(*vector++, absLimit);
-    }
-}
-
-void CopyVector(const float* vector,
-                uint32_t vSize,
-                float* outResult)
-{
-    memcpy(outResult, vector, vSize * sizeof(float));
-}
-
-void SetActivationParameters(uint32_t activation,
-                             armnn::ActivationFunction& outArmnnActivation,
-                             float& outA,
-                             float& outB)
-{
-    switch (activation)
-    {
-        case 0: // None
-            outA = 0;
-            outB = 0;
-            return;
-
-        case 1: // Relu
-            outArmnnActivation = armnn::ActivationFunction::ReLu;
-            outA = 0;
-            outB = 0;
-            return;
-
-        case 3: // Relu6
-            outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
-            outA = 6;
-            outB = 0;
-            return;
-
-        case 4: // Tanh
-            outArmnnActivation = armnn::ActivationFunction::TanH;
-            outA = 1;
-            outB = 1;
-            return;
-
-        case 6: // Sigmoid
-            outArmnnActivation = armnn::ActivationFunction::Sigmoid;
-            outA = 0;
-            outB = 0;
-            return;
-
-        default:
-            throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
-    }
-}
-
-std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
-{
-    if (!ptr)
-    {
-        return nullptr;
-    }
-
-    return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
-}
-
-} // anonymous namespace
-
-namespace armnn
-{
-
-RefLstmFloat32Workload::RefLstmFloat32Workload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
-    : Float32Workload<LstmQueueDescriptor>(descriptor, info)
-    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
-    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
-    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
-    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
-    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
-    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
-    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
-    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
-    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
-    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
-    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
-    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
-    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
-    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
-    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
-    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
-    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
-{}
-
-void RefLstmFloat32Workload::Execute() const
-{
-    // This is a porting of the LSTM::Eval() method in the Android code base
-    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
-
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-    const TensorShape& inputShape = inputInfo.GetShape();
-
-    float* scratchBuffer  = GetOutputTensorDataFloat(0, m_Data);
-    float* outputStateOut = GetOutputTensorDataFloat(1, m_Data);
-    float* cellStateOut   = GetOutputTensorDataFloat(2, m_Data);
-    float* output         = GetOutputTensorDataFloat(3, m_Data);
-
-    const float* inputData     = GetInputTensorDataFloat(0, m_Data);
-    const float* outputStateIn = GetInputTensorDataFloat(1, m_Data);
-    const float* cellStateIn   = GetInputTensorDataFloat(2, m_Data);
-
-    const uint32_t nBatch = inputShape[0];
-    const uint32_t nInput = inputShape[1];
-
-    const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
-    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
-
-    const bool useCifg     = m_Data.m_Parameters.m_CifgEnabled;
-    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
-
-    // Index the scratch buffers pointers to the global scratch buffer.
-    float* inputGateScratch  = nullptr;
-    float* cellScratch       = nullptr;
-    float* forgetGateScratch = nullptr;
-    float* outputGateScratch = nullptr;
-
-    if (useCifg)
-    {
-        cellScratch       = scratchBuffer + 0 * nCell * nBatch;
-        forgetGateScratch = scratchBuffer + 1 * nCell * nBatch;
-        outputGateScratch = scratchBuffer + 2 * nCell * nBatch;
-    }
-    else
-    {
-        inputGateScratch  = scratchBuffer + 0 * nCell * nBatch;
-        cellScratch       = scratchBuffer + 1 * nCell * nBatch;
-        forgetGateScratch = scratchBuffer + 2 * nCell * nBatch;
-        outputGateScratch = scratchBuffer + 3 * nCell * nBatch;
-    }
-
-    // Initialize scratch buffers with bias.
-    if (!useCifg)
-    {
-        VectorBatchVectorAssign(m_InputGateBiasTensor->GetTensor<float>(),
-                                nCell, nBatch, inputGateScratch);
-    }
-    VectorBatchVectorAssign(m_ForgetGateBiasTensor->GetTensor<float>(),
-                            nCell, nBatch, forgetGateScratch);
-    VectorBatchVectorAssign(m_CellBiasTensor->GetTensor<float>(),
-                            nCell, nBatch, cellScratch);
-    VectorBatchVectorAssign(m_OutputGateBiasTensor->GetTensor<float>(),
-                            nCell, nBatch, outputGateScratch);
-
-    // For each batch and cell: compute input_weight * input.
-    if (!useCifg)
-    {
-        MatrixBatchVectorMultiplyAccumulate(m_InputToInputWeightsTensor->GetTensor<float>(),
-                                            nCell, nInput, inputData, nBatch, inputGateScratch);
-    }
-    MatrixBatchVectorMultiplyAccumulate(m_InputToForgetWeightsTensor->GetTensor<float>(),
-                                        nCell, nInput, inputData, nBatch, forgetGateScratch);
-    MatrixBatchVectorMultiplyAccumulate(m_InputToCellWeightsTensor->GetTensor<float>(),
-                                        nCell, nInput, inputData, nBatch, cellScratch);
-    MatrixBatchVectorMultiplyAccumulate(m_InputToOutputWeightsTensor->GetTensor<float>(),
-                                        nCell, nInput, inputData, nBatch, outputGateScratch);
-
-    // For each batch and cell: compute recurrent_weight * output_state.
-    if (!useCifg)
-    {
-        MatrixBatchVectorMultiplyAccumulate(m_RecurrentToInputWeightsTensor->GetTensor<float>(),
-                                            nCell, nOutput, outputStateIn, nBatch, inputGateScratch);
-    }
-    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToForgetWeightsTensor->GetTensor<float>(),
-                                        nCell, nOutput, outputStateIn, nBatch, forgetGateScratch);
-    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToCellWeightsTensor->GetTensor<float>(),
-                                        nCell, nOutput, outputStateIn, nBatch, cellScratch);
-    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToOutputWeightsTensor->GetTensor<float>(),
-                                        nCell, nOutput, outputStateIn, nBatch, outputGateScratch);
-
-    // For each batch and cell: update input gate.
-    if (!useCifg)
-    {
-        if (usePeephole)
-        {
-            VectorBatchVectorCwiseProductAccumulate(m_CellToInputWeightsTensor->GetTensor<float>(),
-                                                    nCell, cellStateIn, nBatch, inputGateScratch);
-        }
-        Activation(inputGateScratch, inputGateScratch,
-                   TensorInfo({nCell, nBatch}, DataType::Float32),
-                   ActivationFunction::Sigmoid, 0, 0);
-    }
-
-    // For each batch and cell: update forget gate.
-    if (usePeephole)
-    {
-        VectorBatchVectorCwiseProductAccumulate(m_CellToForgetWeightsTensor->GetTensor<float>(), nCell,
-                                                cellStateIn, nBatch, forgetGateScratch);
-    }
-    Activation(forgetGateScratch, forgetGateScratch,
-               TensorInfo({nCell, nBatch}, DataType::Float32),
-               ActivationFunction::Sigmoid, 0, 0);
-
-    // For each batch and cell: update the cell.
-    VectorVectorCwiseProduct(forgetGateScratch, cellStateIn, nBatch * nCell, cellStateOut);
-
-    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
-    float a = 0;
-    float b = 0;
-    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);
-
-    if (m_Data.m_Parameters.m_ActivationFunc > 0)
-    {
-        Activation(cellScratch, cellScratch,
-                   TensorInfo({nCell, nBatch}, DataType::Float32),
-                   armnnActivationFunc, a, b);
-    }
-    if (useCifg)
-    {
-        Sub1Vector(forgetGateScratch, nBatch * nCell, forgetGateScratch);
-        VectorVectorCwiseProductAccumulate(cellScratch, forgetGateScratch, nBatch * nCell, cellStateOut);
-    }
-    else
-    {
-        VectorVectorCwiseProductAccumulate(cellScratch, inputGateScratch, nBatch * nCell, cellStateOut);
-    }
-    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
-    {
-        ClipVector(cellStateOut, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, cellStateOut);
-    }
-
-    // For each batch and cell: update the output gate.
-    if (usePeephole)
-    {
-        VectorBatchVectorCwiseProductAccumulate(m_CellToOutputWeightsTensor->GetTensor<float>(),
-                                                nCell, cellStateOut, nBatch, outputGateScratch);
-    }
-    Activation(outputGateScratch, outputGateScratch,
-               TensorInfo({nCell, nBatch}, DataType::Float32),
-               ActivationFunction::Sigmoid, 0, 0);
-
-    if (m_Data.m_Parameters.m_ActivationFunc > 0)
-    {
-        Activation(cellStateOut, cellScratch,
-                   TensorInfo({nCell, nBatch}, DataType::Float32),
-                   armnnActivationFunc, a, b);
-    }
-    VectorVectorCwiseProduct(outputGateScratch, cellScratch, nBatch * nCell, outputGateScratch);
-
-    // For each batch: update the projection and output_state.
-    if (m_Data.m_Parameters.m_ProjectionEnabled)
-    {
-        if (m_ProjectionBiasTensor)
-        {
-            VectorBatchVectorAssign(m_ProjectionBiasTensor->GetTensor<float>(),
-                                    nOutput, nBatch, output);
-        }
-        MatrixBatchVectorMultiplyAccumulate(m_ProjectionWeightsTensor->GetTensor<float>(),
-                                            nOutput, nCell, outputGateScratch, nBatch, output);
-
-        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
-        {
-            ClipVector(output, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, output);
-        }
-    }
-    else
-    {
-        CopyVector(outputGateScratch, nBatch * nOutput, output);
-    }
-
-    CopyVector(output, nBatch * nOutput, outputStateOut);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp b/src/backends/reference/workloads/RefLstmFloat32Workload.hpp
deleted file mode 100644
index a2dead8b9c..0000000000
--- a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp
+++ /dev/null
@@ -1,43 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/TypesUtils.hpp>
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefLstmFloat32Workload : public Float32Workload<LstmQueueDescriptor>
-{
-public:
-    explicit RefLstmFloat32Workload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
-
-    virtual void Execute() const override;
-
-private:
-    std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_CellToForgetWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_CellToOutputWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBiasTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBiasTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_CellBiasTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBiasTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionWeightsTensor;
-    std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionBiasTensor;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp
new file mode 100644
index 0000000000..f8ebc58f6e
--- /dev/null
+++ b/src/backends/reference/workloads/RefLstmWorkload.cpp
@@ -0,0 +1,307 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefLstmWorkload.hpp"
+#include "Activation.hpp"
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+#include "LstmUtils.hpp"
+#include "RefWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
+    : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
+    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
+    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
+    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
+    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
+    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
+    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
+    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
+    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
+    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
+    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
+    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
+    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
+    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
+    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
+    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
+    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
+    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
+{}
+
+void RefLstmWorkload::Execute() const
+{
+    // This is a porting of the LSTM::Eval() method in the Android code base
+    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
+
+    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    const TensorShape& inputShape = inputInfo.GetShape();
+    const DataType& outputType = outputInfo.GetDataType();
+
+    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map());
+    std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+    std::unique_ptr<Encoder<float>> output         = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());
+
+    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+    std::unique_ptr<Decoder<float>> outputDecoder       = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());
+
+    std::unique_ptr<Decoder<float>> inputData     = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map());
+    std::unique_ptr<Decoder<float>> cellStateIn   = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map());
+
+    const uint32_t nBatch = inputShape[0];
+    const uint32_t nInput = inputShape[1];
+
+    const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
+    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
+
+    const bool useCifg     = m_Data.m_Parameters.m_CifgEnabled;
+    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
+
+    // Index the scratch buffers pointers to the global scratch buffer.
+    std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Encoder<float>> cellScratch       = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+    std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
+        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Decoder<float>> cellScratchDecoder =
+        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
+        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
+        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+    if (useCifg)
+    {
+        *cellScratch       += (0 * nCell * nBatch);
+        *forgetGateScratch += (1 * nCell * nBatch);
+        *outputGateScratch += (2 * nCell * nBatch);
+
+        *cellScratchDecoder       += (0 * nCell * nBatch);
+        *forgetGateScratchDecoder += (1 * nCell * nBatch);
+        *outputGateScratchDecoder += (2 * nCell * nBatch);
+    }
+    else
+    {
+        *inputGateScratch  += (0 * nCell * nBatch);
+        *cellScratch       += (1 * nCell * nBatch);
+        *forgetGateScratch += (2 * nCell * nBatch);
+        *outputGateScratch += (3 * nCell * nBatch);
+
+        *inputGateScratchDecoder  += (0 * nCell * nBatch);
+        *cellScratchDecoder       += (1 * nCell * nBatch);
+        *forgetGateScratchDecoder += (2 * nCell * nBatch);
+        *outputGateScratchDecoder += (3 * nCell * nBatch);
+    }
+
+    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
+    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
+        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
+        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
+        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());
+
+    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
+    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
+        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
+        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
+        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());
+
+    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
+    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
+        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
+        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
+        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<void>());
+
+    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
+    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
+    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
+
+    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
+    std::unique_ptr<Decoder<float>> projectionBiasTensor;
+
+    if (!useCifg)
+    {
+        inputToInputWeightsTensor = MakeDecoder<float>(
+            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
+        inputGateBiasTensor = MakeDecoder<float>(
+            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<void>());
+        recurrentToInputWeightsTensor = MakeDecoder<float>(
+            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
+    }
+
+    if (usePeephole)
+    {
+        cellToForgetWeightsTensor = MakeDecoder<float>(
+            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
+        cellToOutputWeightsTensor = MakeDecoder<float>(
+            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
+    }
+
+    if (!useCifg && usePeephole)
+    {
+        cellToInputWeightsTensor = MakeDecoder<float>(
+            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
+    }
+
+    if (m_Data.m_Parameters.m_ProjectionEnabled)
+    {
+        projectionWeightsTensor = MakeDecoder<float>(
+            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
+        if (m_ProjectionBiasTensor)
+        {
+            projectionBiasTensor = MakeDecoder<float>(
+                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
+        }
+    }
+
+    // Initialize scratch buffers with bias.
+    if (!useCifg)
+    {
+        VectorBatchVectorAssign(*inputGateBiasTensor,
+                                nCell, nBatch, *inputGateScratch);
+    }
+    VectorBatchVectorAssign(*forgetGateBiasTensor,
+                            nCell, nBatch, *forgetGateScratch);
+    VectorBatchVectorAssign(*cellBiasTensor,
+                            nCell, nBatch, *cellScratch);
+    VectorBatchVectorAssign(*outputGateBiasTensor,
+                            nCell, nBatch, *outputGateScratch);
+
+    // For each batch and cell: compute input_weight * input.
+    if (!useCifg)
+    {
+        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
+                                            nCell, nInput, *inputData, nBatch, *inputGateScratch);
+    }
+    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *forgetGateScratch);
+    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *cellScratch);
+    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *outputGateScratch);
+
+    // For each batch and cell: compute recurrent_weight * output_state.
+    if (!useCifg)
+    {
+        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
+                                            nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
+    }
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
+
+    // For each batch and cell: update input gate.
+    if (!useCifg)
+    {
+        if (usePeephole)
+        {
+            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
+                                                    nCell, *cellStateIn, nBatch, *inputGateScratch);
+        }
+        Activation(*inputGateScratchDecoder, *inputGateScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   ActivationFunction::Sigmoid, 0, 0);
+    }
+
+    // For each batch and cell: update forget gate.
+    if (usePeephole)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
+                                                *cellStateIn, nBatch, *forgetGateScratch);
+    }
+    Activation(*forgetGateScratchDecoder, *forgetGateScratch,
+               TensorInfo({nCell, nBatch}, outputType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    // For each batch and cell: update the cell.
+    VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
+
+    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
+    float a = 0;
+    float b = 0;
+    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);
+
+    if (m_Data.m_Parameters.m_ActivationFunc > 0)
+    {
+        Activation(*cellScratchDecoder, *cellScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   armnnActivationFunc, a, b);
+    }
+    if (useCifg)
+    {
+        Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
+        VectorVectorCwiseProductAccumulate(
+            *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
+    }
+    else
+    {
+        VectorVectorCwiseProductAccumulate(
+            *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
+    }
+    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
+    {
+        ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut);
+    }
+
+    // For each batch and cell: update the output gate.
+    if (usePeephole)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
+                                                nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
+    }
+    Activation(*outputGateScratchDecoder, *outputGateScratch,
+               TensorInfo({nCell, nBatch}, outputType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    if (m_Data.m_Parameters.m_ActivationFunc > 0)
+    {
+        Activation(*cellStateOutDecoder, *cellScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   armnnActivationFunc, a, b);
+    }
+
+    VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
+
+    // For each batch: update the projection and output_state.
+    if (m_Data.m_Parameters.m_ProjectionEnabled)
+    {
+        if (m_ProjectionBiasTensor)
+        {
+            VectorBatchVectorAssign(*projectionBiasTensor,
+                                    nOutput, nBatch, *output);
+        }
+        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
+                                            nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
+
+        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
+        {
+            ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output);
+        }
+    }
+    else
+    {
+        CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
+    }
+
+    CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
+}
+
+} //namespace armnn
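The numerics of the cell update above are the standard LSTM recurrence: the VectorVectorCwiseProduct / VectorVectorCwiseProductAccumulate sequence computes c_t = f * c_(t-1) + i * g element-wise, with Sub1Vector supplying i = 1 - f when CIFG couples the gates. A scalar sketch of the two variants with hypothetical gate values:

#include <cassert>
#include <cmath>

int main()
{
    // Scalar mirror of the per-element cell update, hypothetical values.
    float cellStateIn = 0.5f; // c_(t-1)
    float forgetGate  = 0.8f; // f, after sigmoid
    float inputGate   = 0.3f; // i, after sigmoid (non-CIFG only)
    float cellInput   = 0.1f; // g, after the cell activation (e.g. tanh)

    // Non-CIFG: c_t = f * c_(t-1) + i * g
    float cellStateOut = forgetGate * cellStateIn + inputGate * cellInput;
    assert(std::fabs(cellStateOut - 0.43f) < 1e-6f);

    // CIFG: the input gate is coupled to the forget gate, i = 1 - f
    float cifgCellStateOut = forgetGate * cellStateIn + (1.0f - forgetGate) * cellInput;
    assert(std::fabs(cifgCellStateOut - 0.42f) < 1e-6f);
}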
diff --git a/src/backends/reference/workloads/RefLstmWorkload.hpp b/src/backends/reference/workloads/RefLstmWorkload.hpp
new file mode 100644
index 0000000000..38e3fb956c
--- /dev/null
+++ b/src/backends/reference/workloads/RefLstmWorkload.hpp
@@ -0,0 +1,43 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/TypesUtils.hpp>
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefLstmWorkload : public BaseWorkload<LstmQueueDescriptor>
+{
+public:
+    explicit RefLstmWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
+
+    virtual void Execute() const override;
+
+private:
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToOutputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionBiasTensor;
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 7871a1b806..8ffd3485ae 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -51,7 +51,7 @@
 #include "Pooling2d.hpp"
 #include "RefFakeQuantizationFloat32Workload.hpp"
 #include "RefPermuteWorkload.hpp"
-#include "RefLstmFloat32Workload.hpp"
+#include "RefLstmWorkload.hpp"
 #include "RefConvertFp16ToFp32Workload.hpp"
 #include "RefConvertFp32ToFp16Workload.hpp"
 #include "RefMeanUint8Workload.hpp"
-- 
cgit v1.2.1