diff options
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 12 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 39 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 2 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/Activation.cpp | 20 | ||||
-rw-r--r-- | src/backends/reference/workloads/Activation.hpp | 7 | ||||
-rw-r--r-- | src/backends/reference/workloads/BaseIterator.hpp | 30 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 5 | ||||
-rw-r--r-- | src/backends/reference/workloads/LstmUtils.hpp | 218 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefLstmFloat32Workload.cpp | 379 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefLstmWorkload.cpp | 307 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefLstmWorkload.hpp (renamed from src/backends/reference/workloads/RefLstmFloat32Workload.hpp) | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 2 |
13 files changed, 603 insertions, 424 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index ca9d7d9c5e..61e0d4004d 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -860,6 +860,18 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "LstmQueueDescriptor", 2, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "LstmQueueDescriptor", 2, "output"); + + std::vector<DataType> supportedTypes = { + DataType::Float32, + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "LstmQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "LstmQueueDescriptor"); } void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 1b1f0ce1c6..67c13c3f84 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -594,12 +594,6 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo* cellToOutputWeights, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(outputStateIn); - ignore_unused(cellStateIn); - ignore_unused(scratchBuffer); - ignore_unused(outputStateOut); - ignore_unused(cellStateOut); - ignore_unused(output); ignore_unused(descriptor); ignore_unused(inputToForgetWeights); ignore_unused(inputToCellWeights); @@ -618,10 +612,35 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, ignore_unused(projectionBias); ignore_unused(cellToForgetWeights); ignore_unused(cellToOutputWeights); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &FalseFuncU8<>); + + bool supported = true; + + std::array<DataType,2> supportedTypes = { + DataType::Float32 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference Lstm: input is not a supported type."); + + supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, + "Reference Lstm: input and outputStateIn types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, + "Reference Lstm: input and cellStateIn types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported, + "Reference Lstm: input and scratchBuffer types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported, + "Reference Lstm: input and outputStateOut types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported, + "Reference Lstm: input and cellStateOut types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference Lstm: input and output types are mismatched"); + + return supported; } bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 8887bb719a..6603aaf27b 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -274,7 +274,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescr std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info); + return std::make_unique<RefLstmWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32( diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 06459ed31b..5034c0fe9e 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -47,7 +47,7 @@ BACKEND_SOURCES := \ workloads/RefFullyConnectedUint8Workload.cpp \ workloads/RefGatherWorkload.cpp \ workloads/RefL2NormalizationFloat32Workload.cpp \ - workloads/RefLstmFloat32Workload.cpp \ + workloads/RefLstmWorkload.cpp \ workloads/RefMeanFloat32Workload.cpp \ workloads/RefMeanUint8Workload.cpp \ workloads/RefMergerFloat32Workload.cpp \ diff --git a/src/backends/reference/workloads/Activation.cpp b/src/backends/reference/workloads/Activation.cpp index 760c9a0ccd..2b0c84e226 100644 --- a/src/backends/reference/workloads/Activation.cpp +++ b/src/backends/reference/workloads/Activation.cpp @@ -88,26 +88,16 @@ void Activation(Decoder<float>& in, float a, float b) { - for (size_t i = 0; i<tensorInfo.GetNumElements(); i++) + unsigned int numElements = tensorInfo.GetNumElements(); + + for (unsigned int i = 0; i < numElements; i++) { out.Set(Activation(in.Get(), function, a, b)); - ++in; ++out; } -} - -void Activation(const float* in, - float* out, - const TensorInfo& tensorInfo, - ActivationFunction function, - float a, - float b) -{ - for (size_t i = 0; i<tensorInfo.GetNumElements(); i++) - { - out[i] = Activation(in[i], function, a, b); - } + in -= numElements; + out -= numElements; } } //namespace armnn diff --git a/src/backends/reference/workloads/Activation.hpp b/src/backends/reference/workloads/Activation.hpp index ffe3c5fc5d..b7fd50c54c 100644 --- a/src/backends/reference/workloads/Activation.hpp +++ b/src/backends/reference/workloads/Activation.hpp @@ -22,11 +22,4 @@ void Activation(Decoder<float>& in, float a, float b); -// This is still used by Reference LSTM implementation -void Activation(const float* in, - float* out, - const TensorInfo& tensorInfo, - ActivationFunction function, - float a, - float b); } //namespace armnn diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index 3439e41b48..97af95a0eb 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -29,8 +29,6 @@ template<typename IType> class Decoder : public BaseIterator { public: - using InterfaceType = IType; - Decoder() {} virtual ~Decoder() {} @@ -42,13 +40,13 @@ template<typename IType> class Encoder : public BaseIterator { public: - using InterfaceType = IType; - Encoder() {} virtual ~Encoder() {} virtual void Set(IType right) = 0; + + virtual IType Get() const = 0; }; template<typename T, typename Base> @@ -77,6 +75,7 @@ public: return *this; } +protected: T* m_Iterator; }; @@ -135,6 +134,11 @@ public: *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset); } + float Get() const override + { + return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset); + } + private: const float m_Scale; const int32_t m_Offset; @@ -151,6 +155,11 @@ public: *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset); } + float Get() const override + { + return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset); + } + private: const float m_Scale; const int32_t m_Offset; @@ -166,6 +175,11 @@ public: { *m_Iterator = right; } + + float Get() const override + { + return *m_Iterator; + } }; class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>> @@ -178,7 +192,11 @@ public: { *m_Iterator = right; } -}; + bool Get() const override + { + return *m_Iterator; + } +}; -} //namespace armnn
\ No newline at end of file +} //namespace armnn diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 596c0993e0..b1cdef9cf1 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -26,6 +26,7 @@ list(APPEND armnnRefBackendWorkloads_sources FullyConnected.hpp Gather.cpp Gather.hpp + LstmUtils.hpp Maximum.hpp Merger.hpp Merger.cpp @@ -80,8 +81,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefGatherWorkload.hpp RefL2NormalizationFloat32Workload.cpp RefL2NormalizationFloat32Workload.hpp - RefLstmFloat32Workload.cpp - RefLstmFloat32Workload.hpp + RefLstmWorkload.cpp + RefLstmWorkload.hpp RefMergerFloat32Workload.cpp RefMergerFloat32Workload.hpp RefMergerUint8Workload.cpp diff --git a/src/backends/reference/workloads/LstmUtils.hpp b/src/backends/reference/workloads/LstmUtils.hpp new file mode 100644 index 0000000000..db02a84a45 --- /dev/null +++ b/src/backends/reference/workloads/LstmUtils.hpp @@ -0,0 +1,218 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "BaseIterator.hpp" +#include <backendsCommon/CpuTensorHandle.hpp> + +namespace +{ + +// Helper functions ported from the Android code base +// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc + +void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder<float>& matrix, + uint32_t mRows, + uint32_t mCols, + armnn::Decoder<float>& vector, + uint32_t nBatch, + armnn::Encoder<float>& outResult) +{ + for (uint32_t b = 0; b < nBatch; b++) + { + for (uint32_t r = 0; r < mRows; r++) + { + vector += b * mCols; + for (uint32_t c = 0; c < mCols; c++) + { + outResult.Set(outResult.Get() + matrix.Get() * vector.Get()); + ++matrix; + ++vector; + } + outResult += 1; + vector -= (b+1) * mCols; + } + matrix -= (mRows * mCols); + } + outResult -= (mRows * nBatch); +} + +void VectorBatchVectorAssign(armnn::Decoder<float>& vector, + uint32_t vSize, + uint32_t nBatch, + armnn::Encoder<float>& outBatchVector) +{ + for (uint32_t b = 0; b < nBatch; b++) + { + for (uint32_t v = 0; v < vSize; v++) + { + outBatchVector.Set(vector.Get()); + ++outBatchVector; + ++vector; + } + vector -= vSize; + } + outBatchVector -= (nBatch * vSize); +} + +void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder<float>& vector, + uint32_t vSize, + armnn::Decoder<float>& batchVector, + uint32_t nBatch, + armnn::Encoder<float>& outResult) +{ + for (uint32_t b = 0; b < nBatch; b++) + { + for (uint32_t v = 0; v < vSize; v++) + { + outResult.Set(outResult.Get() + vector.Get() * batchVector.Get()); + ++outResult; + ++vector; + ++batchVector; + } + vector -= vSize; + } + batchVector -= vSize * nBatch; + outResult -= vSize * nBatch; +} + +void Sub1Vector(armnn::Decoder<float>& vector, + uint32_t vSize, + armnn::Encoder<float>& result) +{ + for (uint32_t v = 0; v < vSize; v++) + { + result.Set(1.0f - vector.Get()); + ++vector; + ++result; + } + vector -= vSize; + result -= vSize; +} + +void VectorVectorCwiseProduct(armnn::Decoder<float>& vector1, + armnn::Decoder<float>& vector2, + uint32_t vSize, + armnn::Encoder<float>& outResult) +{ + for (uint32_t v = 0; v < vSize; v++) + { + outResult.Set(vector1.Get() * vector2.Get()); + ++outResult; + ++vector1; + ++vector2; + } + outResult -= vSize; + vector1 -= vSize; + vector2 -= vSize; +} + +void VectorVectorCwiseProductAccumulate(armnn::Decoder<float>& vector1, + armnn::Decoder<float>& vector2, + uint32_t vSize, + armnn::Encoder<float>& outResult) +{ + for (uint32_t v = 0; v < vSize; v++) + { + outResult.Set(outResult.Get() + vector1.Get() * vector2.Get()); + ++outResult; + ++vector1; + ++vector2; + } + outResult -= vSize; + vector1 -= vSize; + vector2 -= vSize; +} + +float Clip(float f, + float absLimit) +{ + float result = (absLimit < f) ? absLimit : f; + result = (-absLimit > result) ? -absLimit : result; + return result; +} + +void ClipVector(armnn::Decoder<float>& vector, + uint32_t vSize, + float absLimit, + armnn::Encoder<float>& outResult) +{ + for (uint32_t v = 0; v < vSize; v++) + { + outResult.Set(Clip(vector.Get(), absLimit)); + ++vector; + ++outResult; + } + vector -= vSize; + outResult -= vSize; +} + +void CopyVector(armnn::Decoder<float>& vector, + uint32_t vSize, + armnn::Encoder<float>& outResult) +{ + for (uint32_t v = 0; v < vSize; v++) + { + outResult.Set(vector.Get()); + ++outResult; + ++vector; + } + outResult -= vSize; + vector -= vSize; +} + +void SetActivationParameters(uint32_t activation, + armnn::ActivationFunction& outArmnnActivation, + float& outA, + float& outB) +{ + switch (activation) + { + case 0: // None + outA = 0; + outB = 0; + return; + + case 1: // Relu + outArmnnActivation = armnn::ActivationFunction::ReLu; + outA = 0; + outB = 0; + return; + + case 3: // Relu6 + outArmnnActivation = armnn::ActivationFunction::BoundedReLu; + outA = 6; + outB = 0; + return; + + case 4: // Tanh + outArmnnActivation = armnn::ActivationFunction::TanH; + outA = 1; + outB = 1; + return; + + case 6: // Sigmoid + outArmnnActivation = armnn::ActivationFunction::Sigmoid; + outA = 0; + outB = 0; + return; + + default: + throw armnn::Exception("Unsupported activation function: " + std::to_string(activation)); + } +} + +std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr) +{ + if (!ptr) + { + return nullptr; + } + + return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr); +} + +} // anonymous namespace diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.cpp b/src/backends/reference/workloads/RefLstmFloat32Workload.cpp deleted file mode 100644 index c697b66658..0000000000 --- a/src/backends/reference/workloads/RefLstmFloat32Workload.cpp +++ /dev/null @@ -1,379 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefLstmFloat32Workload.hpp" -#include "RefWorkloadUtils.hpp" -#include "Activation.hpp" - -namespace -{ - -// Helper functions ported from the Android code base -// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc - -void MatrixBatchVectorMultiplyAccumulate(const float* matrix, - uint32_t mRows, - uint32_t mCols, - const float* vector, - uint32_t nBatch, - float* outResult, - int resultStride = 1) -{ - float* resultInBatch = outResult; - for (uint32_t b = 0; b < nBatch; b++) - { - const float* matrixPtr = matrix; - for (uint32_t r = 0; r < mRows; r++) - { - const float* vectorInBatch = vector + b * mCols; - for (uint32_t c = 0; c < mCols; c++) - { - *resultInBatch += *matrixPtr++ * *vectorInBatch++; - } - resultInBatch += resultStride; - } - } -} - -void VectorBatchVectorAssign(const float* vector, - uint32_t vSize, - uint32_t nBatch, - float* outBatchVector) -{ - for (uint32_t b = 0; b < nBatch; b++) - { - memcpy(outBatchVector + b * vSize, vector, vSize * sizeof(float)); - } -} - -void VectorBatchVectorCwiseProductAccumulate(const float* vector, - uint32_t vSize, - const float* batchVector, - uint32_t nBatch, - float* outResult) -{ - for (uint32_t b = 0; b < nBatch; b++) - { - for (uint32_t v = 0; v < vSize; v++) - { - *outResult++ += vector[v] * *batchVector++; - } - } -} - -void Sub1Vector(const float* vector, - uint32_t vSize, - float* result) -{ - for (uint32_t v = 0; v < vSize; v++) - { - *result++ = 1.0f - *vector++; - } -} - -void VectorVectorCwiseProduct(const float* vector1, - const float* vector2, - uint32_t vSize, - float* outResult) -{ - for (uint32_t v = 0; v < vSize; v++) - { - *outResult++ = *vector1++ * *vector2++; - } -} - -void VectorVectorCwiseProductAccumulate(const float* vector1, - const float* vector2, - uint32_t vSize, - float* outResult) -{ - for (uint32_t v = 0; v < vSize; v++) - { - *outResult++ += *vector1++ * *vector2++; - } -} - -float Clip(float f, - float absLimit) -{ - float result = (absLimit < f) ? absLimit : f; - result = (-absLimit > result) ? -absLimit : result; - return result; -} - -void ClipVector(const float* vector, - uint32_t vSize, - float absLimit, - float* outResult) -{ - for (uint32_t v = 0; v < vSize; v++) - { - *outResult++ = Clip(*vector++, absLimit); - } -} - -void CopyVector(const float* vector, - uint32_t vSize, - float* outResult) -{ - memcpy(outResult, vector, vSize * sizeof(float)); -} - -void SetActivationParameters(uint32_t activation, - armnn::ActivationFunction& outArmnnActivation, - float& outA, - float& outB) -{ - switch (activation) - { - case 0: // None - outA = 0; - outB = 0; - return; - - case 1: // Relu - outArmnnActivation = armnn::ActivationFunction::ReLu; - outA = 0; - outB = 0; - return; - - case 3: // Relu6 - outArmnnActivation = armnn::ActivationFunction::BoundedReLu; - outA = 6; - outB = 0; - return; - - case 4: // Tanh - outArmnnActivation = armnn::ActivationFunction::TanH; - outA = 1; - outB = 1; - return; - - case 6: // Sigmoid - outArmnnActivation = armnn::ActivationFunction::Sigmoid; - outA = 0; - outB = 0; - return; - - default: - throw armnn::Exception("Unsupported activation function: " + std::to_string(activation)); - } -} - -std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr) -{ - if (!ptr) - { - return nullptr; - } - - return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr); -} - -} // anonymous namespace - -namespace armnn -{ - -RefLstmFloat32Workload::RefLstmFloat32Workload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) - : Float32Workload<LstmQueueDescriptor>(descriptor, info) - , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights)) - , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights)) - , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights)) - , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights)) - , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights)) - , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights)) - , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights)) - , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights)) - , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights)) - , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights)) - , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights)) - , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias)) - , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias)) - , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias)) - , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias)) - , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights)) - , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias)) -{} - -void RefLstmFloat32Workload::Execute() const -{ - // This is a porting of the LSTM::Eval() method in the Android code base - // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp - - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorShape& inputShape = inputInfo.GetShape(); - - float* scratchBuffer = GetOutputTensorDataFloat(0, m_Data); - float* outputStateOut = GetOutputTensorDataFloat(1, m_Data); - float* cellStateOut = GetOutputTensorDataFloat(2, m_Data); - float* output = GetOutputTensorDataFloat(3, m_Data); - - const float* inputData = GetInputTensorDataFloat(0, m_Data); - const float* outputStateIn = GetInputTensorDataFloat(1, m_Data); - const float* cellStateIn = GetInputTensorDataFloat(2, m_Data); - - const uint32_t nBatch = inputShape[0]; - const uint32_t nInput = inputShape[1]; - - const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; - const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; - - const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; - const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; - - // Index the scratch buffers pointers to the global scratch buffer. - float* inputGateScratch = nullptr; - float* cellScratch = nullptr; - float* forgetGateScratch = nullptr; - float* outputGateScratch = nullptr; - - if (useCifg) - { - cellScratch = scratchBuffer + 0 * nCell * nBatch; - forgetGateScratch = scratchBuffer + 1 * nCell * nBatch; - outputGateScratch = scratchBuffer + 2 * nCell * nBatch; - } - else - { - inputGateScratch = scratchBuffer + 0 * nCell * nBatch; - cellScratch = scratchBuffer + 1 * nCell * nBatch; - forgetGateScratch = scratchBuffer + 2 * nCell * nBatch; - outputGateScratch = scratchBuffer + 3 * nCell * nBatch; - } - - // Initialize scratch buffers with bias. - if (!useCifg) - { - VectorBatchVectorAssign(m_InputGateBiasTensor->GetTensor<float>(), - nCell, nBatch, inputGateScratch); - } - VectorBatchVectorAssign(m_ForgetGateBiasTensor->GetTensor<float>(), - nCell, nBatch, forgetGateScratch); - VectorBatchVectorAssign(m_CellBiasTensor->GetTensor<float>(), - nCell, nBatch, cellScratch); - VectorBatchVectorAssign(m_OutputGateBiasTensor->GetTensor<float>(), - nCell, nBatch, outputGateScratch); - - // For each batch and cell: compute input_weight * input. - if (!useCifg) - { - MatrixBatchVectorMultiplyAccumulate(m_InputToInputWeightsTensor->GetTensor<float>(), - nCell, nInput, inputData, nBatch, inputGateScratch); - } - MatrixBatchVectorMultiplyAccumulate(m_InputToForgetWeightsTensor->GetTensor<float>(), - nCell, nInput, inputData, nBatch, forgetGateScratch); - MatrixBatchVectorMultiplyAccumulate(m_InputToCellWeightsTensor->GetTensor<float>(), - nCell, nInput, inputData, nBatch, cellScratch); - MatrixBatchVectorMultiplyAccumulate(m_InputToOutputWeightsTensor->GetTensor<float>(), - nCell, nInput, inputData, nBatch, outputGateScratch); - - // For each batch and cell: compute recurrent_weight * output_state. - if (!useCifg) - { - MatrixBatchVectorMultiplyAccumulate(m_RecurrentToInputWeightsTensor->GetTensor<float>(), - nCell, nOutput, outputStateIn, nBatch, inputGateScratch); - } - MatrixBatchVectorMultiplyAccumulate(m_RecurrentToForgetWeightsTensor->GetTensor<float>(), - nCell, nOutput, outputStateIn, nBatch, forgetGateScratch); - MatrixBatchVectorMultiplyAccumulate(m_RecurrentToCellWeightsTensor->GetTensor<float>(), - nCell, nOutput, outputStateIn, nBatch, cellScratch); - MatrixBatchVectorMultiplyAccumulate(m_RecurrentToOutputWeightsTensor->GetTensor<float>(), - nCell, nOutput, outputStateIn, nBatch, outputGateScratch); - - // For each batch and cell: update input gate. - if (!useCifg) - { - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(m_CellToInputWeightsTensor->GetTensor<float>(), - nCell, cellStateIn, nBatch, inputGateScratch); - } - Activation(inputGateScratch, inputGateScratch, - TensorInfo({nCell, nBatch}, DataType::Float32), - ActivationFunction::Sigmoid, 0, 0); - } - - // For each batch and cell: update forget gate. - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(m_CellToForgetWeightsTensor->GetTensor<float>(), nCell, - cellStateIn, nBatch, forgetGateScratch); - } - Activation(forgetGateScratch, forgetGateScratch, - TensorInfo({nCell, nBatch}, DataType::Float32), - ActivationFunction::Sigmoid, 0, 0); - - // For each batch and cell: update the cell. - VectorVectorCwiseProduct(forgetGateScratch, cellStateIn, nBatch * nCell, cellStateOut); - - ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; - float a = 0; - float b = 0; - SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b); - - if (m_Data.m_Parameters.m_ActivationFunc > 0) - { - Activation(cellScratch, cellScratch, - TensorInfo({nCell, nBatch}, DataType::Float32), - armnnActivationFunc, a, b); - } - if (useCifg) - { - Sub1Vector(forgetGateScratch, nBatch * nCell, forgetGateScratch); - VectorVectorCwiseProductAccumulate(cellScratch, forgetGateScratch, nBatch * nCell, cellStateOut); - } - else - { - VectorVectorCwiseProductAccumulate(cellScratch, inputGateScratch, nBatch * nCell, cellStateOut); - } - if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) - { - ClipVector(cellStateOut, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, cellStateOut); - } - - // For each batch and cell: update the output gate. - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(m_CellToOutputWeightsTensor->GetTensor<float>(), - nCell, cellStateOut, nBatch, outputGateScratch); - } - Activation(outputGateScratch, outputGateScratch, - TensorInfo({nCell, nBatch}, DataType::Float32), - ActivationFunction::Sigmoid, 0, 0); - - if (m_Data.m_Parameters.m_ActivationFunc > 0) - { - Activation(cellStateOut, cellScratch, - TensorInfo({nCell, nBatch}, DataType::Float32), - armnnActivationFunc, a, b); - } - VectorVectorCwiseProduct(outputGateScratch, cellScratch, nBatch * nCell, outputGateScratch); - - // For each batch: update the projection and output_state. - if (m_Data.m_Parameters.m_ProjectionEnabled) - { - if (m_ProjectionBiasTensor) - { - VectorBatchVectorAssign(m_ProjectionBiasTensor->GetTensor<float>(), - nOutput, nBatch, output); - } - MatrixBatchVectorMultiplyAccumulate(m_ProjectionWeightsTensor->GetTensor<float>(), - nOutput, nCell, outputGateScratch, nBatch, output); - - if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) - { - ClipVector(output, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, output); - } - } - else - { - CopyVector(outputGateScratch, nBatch * nOutput, output); - } - - CopyVector(output, nBatch * nOutput, outputStateOut); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp new file mode 100644 index 0000000000..f8ebc58f6e --- /dev/null +++ b/src/backends/reference/workloads/RefLstmWorkload.cpp @@ -0,0 +1,307 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefLstmWorkload.hpp" +#include "Activation.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "LstmUtils.hpp" +#include "RefWorkloadUtils.hpp" + +namespace armnn +{ + +RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) + : BaseWorkload<LstmQueueDescriptor>(descriptor, info) + , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights)) + , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights)) + , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights)) + , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights)) + , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights)) + , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights)) + , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights)) + , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights)) + , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights)) + , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights)) + , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights)) + , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias)) + , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias)) + , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias)) + , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias)) + , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights)) + , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias)) +{} + +void RefLstmWorkload::Execute() const +{ + // This is a porting of the LSTM::Eval() method in the Android code base + // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp + + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + const TensorShape& inputShape = inputInfo.GetShape(); + const DataType& outputType = outputInfo.GetDataType(); + + std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map()); + std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map()); + std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map()); + + std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map()); + std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map()); + + std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map()); + std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map()); + std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map()); + + const uint32_t nBatch = inputShape[0]; + const uint32_t nInput = inputShape[1]; + + const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; + const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; + + const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; + const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; + + // Index the scratch buffers pointers to the global scratch buffer. + std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); + + std::unique_ptr<Decoder<float>> inputGateScratchDecoder = + MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr<Decoder<float>> cellScratchDecoder = + MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = + MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr<Decoder<float>> outputGateScratchDecoder = + MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); + + if (useCifg) + { + *cellScratch += (0 * nCell * nBatch); + *forgetGateScratch += (1 * nCell * nBatch); + *outputGateScratch += (2 * nCell * nBatch); + + *cellScratchDecoder += (0 * nCell * nBatch); + *forgetGateScratchDecoder += (1 * nCell * nBatch); + *outputGateScratchDecoder += (2 * nCell * nBatch); + } + else + { + *inputGateScratch += (0 * nCell * nBatch); + *cellScratch += (1 * nCell * nBatch); + *forgetGateScratch += (2 * nCell * nBatch); + *outputGateScratch += (3 * nCell * nBatch); + + *inputGateScratchDecoder += (0 * nCell * nBatch); + *cellScratchDecoder += (1 * nCell * nBatch); + *forgetGateScratchDecoder += (2 * nCell * nBatch); + *outputGateScratchDecoder += (3 * nCell * nBatch); + } + + std::unique_ptr<Decoder<float>> inputToInputWeightsTensor; + std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>( + m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>()); + std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>( + m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>()); + std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>( + m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>()); + + std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor; + std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>( + m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>()); + std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>( + m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>()); + std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>( + m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>()); + + std::unique_ptr<Decoder<float>> inputGateBiasTensor; + std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>( + m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<void>()); + std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>( + m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<void>()); + std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>( + m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<void>()); + + std::unique_ptr<Decoder<float>> cellToInputWeightsTensor; + std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor; + std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor; + + std::unique_ptr<Decoder<float>> projectionWeightsTensor; + std::unique_ptr<Decoder<float>> projectionBiasTensor; + + if (!useCifg) + { + inputToInputWeightsTensor = MakeDecoder<float>( + m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>()); + inputGateBiasTensor = MakeDecoder<float>( + m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<void>()); + recurrentToInputWeightsTensor = MakeDecoder<float>( + m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>()); + } + + if (usePeephole) + { + cellToForgetWeightsTensor = MakeDecoder<float>( + m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>()); + cellToOutputWeightsTensor = MakeDecoder<float>( + m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>()); + } + + if (!useCifg && usePeephole) + { + cellToInputWeightsTensor = MakeDecoder<float>( + m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>()); + } + + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + projectionWeightsTensor = MakeDecoder<float>( + m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>()); + if (m_ProjectionBiasTensor) + { + projectionBiasTensor = MakeDecoder<float>( + m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>()); + } + } + + // Initialize scratch buffers with bias. + if (!useCifg) + { + VectorBatchVectorAssign(*inputGateBiasTensor, + nCell, nBatch, *inputGateScratch); + } + VectorBatchVectorAssign(*forgetGateBiasTensor, + nCell, nBatch, *forgetGateScratch); + VectorBatchVectorAssign(*cellBiasTensor, + nCell, nBatch, *cellScratch); + VectorBatchVectorAssign(*outputGateBiasTensor, + nCell, nBatch, *outputGateScratch); + + // For each batch and cell: compute input_weight * input. + if (!useCifg) + { + MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, + nCell, nInput, *inputData, nBatch, *inputGateScratch); + } + MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, + nCell, nInput, *inputData, nBatch, *forgetGateScratch); + MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, + nCell, nInput, *inputData, nBatch, *cellScratch); + MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, + nCell, nInput, *inputData, nBatch, *outputGateScratch); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!useCifg) + { + MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); + } + MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); + MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *cellScratch); + MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, + nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); + + // For each batch and cell: update input gate. + if (!useCifg) + { + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, + nCell, *cellStateIn, nBatch, *inputGateScratch); + } + Activation(*inputGateScratchDecoder, *inputGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + } + + // For each batch and cell: update forget gate. + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, + *cellStateIn, nBatch, *forgetGateScratch); + } + Activation(*forgetGateScratchDecoder, *forgetGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + + // For each batch and cell: update the cell. + VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); + + ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; + float a = 0; + float b = 0; + SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b); + + if (m_Data.m_Parameters.m_ActivationFunc > 0) + { + Activation(*cellScratchDecoder, *cellScratch, + TensorInfo({nCell, nBatch}, outputType), + armnnActivationFunc, a, b); + } + if (useCifg) + { + Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); + VectorVectorCwiseProductAccumulate( + *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); + } + else + { + VectorVectorCwiseProductAccumulate( + *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); + } + if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) + { + ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut); + } + + // For each batch and cell: update the output gate. + if (usePeephole) + { + VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, + nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); + } + Activation(*outputGateScratchDecoder, *outputGateScratch, + TensorInfo({nCell, nBatch}, outputType), + ActivationFunction::Sigmoid, 0, 0); + + if (m_Data.m_Parameters.m_ActivationFunc > 0) + { + Activation(*cellStateOutDecoder, *cellScratch, + TensorInfo({nCell, nBatch}, outputType), + armnnActivationFunc, a, b); + } + + VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); + + // For each batch: update the projection and output_state. + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + if (m_ProjectionBiasTensor) + { + VectorBatchVectorAssign(*projectionBiasTensor, + nOutput, nBatch, *output); + } + MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, + nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); + + if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) + { + ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output); + } + } + else + { + CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); + } + + CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp b/src/backends/reference/workloads/RefLstmWorkload.hpp index a2dead8b9c..38e3fb956c 100644 --- a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp +++ b/src/backends/reference/workloads/RefLstmWorkload.hpp @@ -13,10 +13,10 @@ namespace armnn { -class RefLstmFloat32Workload : public Float32Workload<LstmQueueDescriptor> +class RefLstmWorkload : public BaseWorkload<LstmQueueDescriptor> { public: - explicit RefLstmFloat32Workload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info); + explicit RefLstmWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info); virtual void Execute() const override; diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 7871a1b806..8ffd3485ae 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -51,7 +51,7 @@ #include "Pooling2d.hpp" #include "RefFakeQuantizationFloat32Workload.hpp" #include "RefPermuteWorkload.hpp" -#include "RefLstmFloat32Workload.hpp" +#include "RefLstmWorkload.hpp" #include "RefConvertFp16ToFp32Workload.hpp" #include "RefConvertFp32ToFp16Workload.hpp" #include "RefMeanUint8Workload.hpp" |