author     Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>   2019-05-07 12:02:30 +0100
committer  derek.lamberti <derek.lamberti@arm.com>                 2019-05-07 12:15:38 +0000
commit     eb2b329b761ce3206505ed8d2eab071a2f97d5e7 (patch)
tree       a236bea2077d391fdae90cdbee866f6754897623
parent     5cf4d1c29a36bb1d675a7cbe2d24b688deb7d160 (diff)
IVGCVSW-2997 Refactor reference LSTM workload
Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Change-Id: I6883f878d9f701a55153292769d2fc0530d2529e
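The core of the refactor is to replace the raw float* helpers with type-erased Decoder<float>/Encoder<float> iterators, so a single RefLstmWorkload can serve whatever tensor data types the decoders support. A minimal sketch of the pattern (illustrative only; ScaleVector and its arguments are not part of this patch):

    // Assumes the reference backend's BaseIterator.hpp is available.
    void ScaleVector(armnn::Decoder<float>& in, armnn::Encoder<float>& out,
                     unsigned int n, float scale)
    {
        for (unsigned int i = 0; i < n; i++)
        {
            out.Set(in.Get() * scale); // Get()/Set() handle any (de)quantization for the underlying type
            ++in;
            ++out;
        }
        in  -= n;                      // rewind so the caller can reuse the same iterators
        out -= n;
    }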
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp                  |  12
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp                    |  39
-rw-r--r--  src/backends/reference/RefWorkloadFactory.cpp                 |   2
-rw-r--r--  src/backends/reference/backend.mk                             |   2
-rw-r--r--  src/backends/reference/workloads/Activation.cpp               |  20
-rw-r--r--  src/backends/reference/workloads/Activation.hpp               |   7
-rw-r--r--  src/backends/reference/workloads/BaseIterator.hpp             |  30
-rw-r--r--  src/backends/reference/workloads/CMakeLists.txt               |   5
-rw-r--r--  src/backends/reference/workloads/LstmUtils.hpp                | 218
-rw-r--r--  src/backends/reference/workloads/RefLstmFloat32Workload.cpp   | 379
-rw-r--r--  src/backends/reference/workloads/RefLstmWorkload.cpp          | 307
-rw-r--r--  src/backends/reference/workloads/RefLstmWorkload.hpp (renamed from src/backends/reference/workloads/RefLstmFloat32Workload.hpp) | 4
-rw-r--r--  src/backends/reference/workloads/RefWorkloads.hpp             |   2
13 files changed, 603 insertions, 424 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index ca9d7d9c5e..61e0d4004d 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -860,6 +860,18 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "LstmQueueDescriptor", 2, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "LstmQueueDescriptor", 2, "output");
+
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "LstmQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "LstmQueueDescriptor");
}
void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 1b1f0ce1c6..67c13c3f84 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -594,12 +594,6 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
const TensorInfo* cellToOutputWeights,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(outputStateIn);
- ignore_unused(cellStateIn);
- ignore_unused(scratchBuffer);
- ignore_unused(outputStateOut);
- ignore_unused(cellStateOut);
- ignore_unused(output);
ignore_unused(descriptor);
ignore_unused(inputToForgetWeights);
ignore_unused(inputToCellWeights);
@@ -618,10 +612,35 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
ignore_unused(projectionBias);
ignore_unused(cellToForgetWeights);
ignore_unused(cellToOutputWeights);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &FalseFuncU8<>);
+
+ bool supported = true;
+
+ std::array<DataType,1> supportedTypes = {
+ DataType::Float32
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference Lstm: input is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
+ "Reference Lstm: input and outputStateIn types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
+ "Reference Lstm: input and cellStateIn types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
+ "Reference Lstm: input and scratchBuffer types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
+ "Reference Lstm: input and outputStateOut types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
+ "Reference Lstm: input and cellStateOut types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference Lstm: input and output types are mismatched");
+
+ return supported;
}
bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 8887bb719a..6603aaf27b 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -274,7 +274,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescr
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
+ return std::make_unique<RefLstmWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 06459ed31b..5034c0fe9e 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -47,7 +47,7 @@ BACKEND_SOURCES := \
workloads/RefFullyConnectedUint8Workload.cpp \
workloads/RefGatherWorkload.cpp \
workloads/RefL2NormalizationFloat32Workload.cpp \
- workloads/RefLstmFloat32Workload.cpp \
+ workloads/RefLstmWorkload.cpp \
workloads/RefMeanFloat32Workload.cpp \
workloads/RefMeanUint8Workload.cpp \
workloads/RefMergerFloat32Workload.cpp \
diff --git a/src/backends/reference/workloads/Activation.cpp b/src/backends/reference/workloads/Activation.cpp
index 760c9a0ccd..2b0c84e226 100644
--- a/src/backends/reference/workloads/Activation.cpp
+++ b/src/backends/reference/workloads/Activation.cpp
@@ -88,26 +88,16 @@ void Activation(Decoder<float>& in,
float a,
float b)
{
- for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
+ unsigned int numElements = tensorInfo.GetNumElements();
+
+ for (unsigned int i = 0; i < numElements; i++)
{
out.Set(Activation(in.Get(), function, a, b));
-
++in;
++out;
}
-}
-
-void Activation(const float* in,
- float* out,
- const TensorInfo& tensorInfo,
- ActivationFunction function,
- float a,
- float b)
-{
- for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
- {
- out[i] = Activation(in[i], function, a, b);
- }
+ in -= numElements;
+ out -= numElements;
}
} //namespace armnn
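The rewritten Activation() now rewinds both iterators by numElements before returning, so the LSTM workload can apply it repeatedly to the same Decoder/Encoder pair without re-creating them. A hedged usage sketch (gateInfo, nCell, nBatch and the scratch iterators are assumed to exist as they do in the LSTM workload further below):

    armnn::TensorInfo gateInfo({nCell, nBatch}, armnn::DataType::Float32);
    // First pass: sigmoid over the forget-gate scratch area.
    armnn::Activation(*forgetGateScratchDecoder, *forgetGateScratch,
                      gateInfo, armnn::ActivationFunction::Sigmoid, 0.f, 0.f);
    // Because Activation() rewound both iterators, a later call starts at element 0 again.
    armnn::Activation(*forgetGateScratchDecoder, *forgetGateScratch,
                      gateInfo, armnn::ActivationFunction::Sigmoid, 0.f, 0.f);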
diff --git a/src/backends/reference/workloads/Activation.hpp b/src/backends/reference/workloads/Activation.hpp
index ffe3c5fc5d..b7fd50c54c 100644
--- a/src/backends/reference/workloads/Activation.hpp
+++ b/src/backends/reference/workloads/Activation.hpp
@@ -22,11 +22,4 @@ void Activation(Decoder<float>& in,
float a,
float b);
-// This is still used by Reference LSTM implementation
-void Activation(const float* in,
- float* out,
- const TensorInfo& tensorInfo,
- ActivationFunction function,
- float a,
- float b);
} //namespace armnn
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 3439e41b48..97af95a0eb 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -29,8 +29,6 @@ template<typename IType>
class Decoder : public BaseIterator
{
public:
- using InterfaceType = IType;
-
Decoder() {}
virtual ~Decoder() {}
@@ -42,13 +40,13 @@ template<typename IType>
class Encoder : public BaseIterator
{
public:
- using InterfaceType = IType;
-
Encoder() {}
virtual ~Encoder() {}
virtual void Set(IType right) = 0;
+
+ virtual IType Get() const = 0;
};
template<typename T, typename Base>
@@ -77,6 +75,7 @@ public:
return *this;
}
+protected:
T* m_Iterator;
};
@@ -135,6 +134,11 @@ public:
*m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
}
+ float Get() const override
+ {
+ return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+ }
+
private:
const float m_Scale;
const int32_t m_Offset;
@@ -151,6 +155,11 @@ public:
*m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
}
+ float Get() const override
+ {
+ return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+ }
+
private:
const float m_Scale;
const int32_t m_Offset;
@@ -166,6 +175,11 @@ public:
{
*m_Iterator = right;
}
+
+ float Get() const override
+ {
+ return *m_Iterator;
+ }
};
class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
@@ -178,7 +192,11 @@ public:
{
*m_Iterator = right;
}
-};
+ bool Get() const override
+ {
+ return *m_Iterator;
+ }
+};
-} //namespace armnn
\ No newline at end of file
+} //namespace armnn
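Adding Get() to Encoder makes read-modify-write accumulation possible directly through the encoder, which the LstmUtils helpers introduced below rely on. A small sketch of the idiom (AccumulateInto is illustrative, not part of this patch):

    void AccumulateInto(armnn::Encoder<float>& acc, armnn::Decoder<float>& addend, unsigned int n)
    {
        for (unsigned int i = 0; i < n; i++)
        {
            acc.Set(acc.Get() + addend.Get()); // Get() on an encoder reads back the stored value
            ++acc;
            ++addend;
        }
        acc    -= n; // leave both iterators where they started
        addend -= n;
    }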
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 596c0993e0..b1cdef9cf1 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -26,6 +26,7 @@ list(APPEND armnnRefBackendWorkloads_sources
FullyConnected.hpp
Gather.cpp
Gather.hpp
+ LstmUtils.hpp
Maximum.hpp
Merger.hpp
Merger.cpp
@@ -80,8 +81,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefGatherWorkload.hpp
RefL2NormalizationFloat32Workload.cpp
RefL2NormalizationFloat32Workload.hpp
- RefLstmFloat32Workload.cpp
- RefLstmFloat32Workload.hpp
+ RefLstmWorkload.cpp
+ RefLstmWorkload.hpp
RefMergerFloat32Workload.cpp
RefMergerFloat32Workload.hpp
RefMergerUint8Workload.cpp
diff --git a/src/backends/reference/workloads/LstmUtils.hpp b/src/backends/reference/workloads/LstmUtils.hpp
new file mode 100644
index 0000000000..db02a84a45
--- /dev/null
+++ b/src/backends/reference/workloads/LstmUtils.hpp
@@ -0,0 +1,218 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+#include <backendsCommon/CpuTensorHandle.hpp>
+
+namespace
+{
+
+// Helper functions ported from the Android code base
+// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+
+void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder<float>& matrix,
+ uint32_t mRows,
+ uint32_t mCols,
+ armnn::Decoder<float>& vector,
+ uint32_t nBatch,
+ armnn::Encoder<float>& outResult)
+{
+ for (uint32_t b = 0; b < nBatch; b++)
+ {
+ for (uint32_t r = 0; r < mRows; r++)
+ {
+ vector += b * mCols;
+ for (uint32_t c = 0; c < mCols; c++)
+ {
+ outResult.Set(outResult.Get() + matrix.Get() * vector.Get());
+ ++matrix;
+ ++vector;
+ }
+ outResult += 1;
+ vector -= (b+1) * mCols;
+ }
+ matrix -= (mRows * mCols);
+ }
+ outResult -= (mRows * nBatch);
+}
+
+void VectorBatchVectorAssign(armnn::Decoder<float>& vector,
+ uint32_t vSize,
+ uint32_t nBatch,
+ armnn::Encoder<float>& outBatchVector)
+{
+ for (uint32_t b = 0; b < nBatch; b++)
+ {
+ for (uint32_t v = 0; v < vSize; v++)
+ {
+ outBatchVector.Set(vector.Get());
+ ++outBatchVector;
+ ++vector;
+ }
+ vector -= vSize;
+ }
+ outBatchVector -= (nBatch * vSize);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder<float>& vector,
+ uint32_t vSize,
+ armnn::Decoder<float>& batchVector,
+ uint32_t nBatch,
+ armnn::Encoder<float>& outResult)
+{
+ for (uint32_t b = 0; b < nBatch; b++)
+ {
+ for (uint32_t v = 0; v < vSize; v++)
+ {
+ outResult.Set(outResult.Get() + vector.Get() * batchVector.Get());
+ ++outResult;
+ ++vector;
+ ++batchVector;
+ }
+ vector -= vSize;
+ }
+ batchVector -= vSize * nBatch;
+ outResult -= vSize * nBatch;
+}
+
+void Sub1Vector(armnn::Decoder<float>& vector,
+ uint32_t vSize,
+ armnn::Encoder<float>& result)
+{
+ for (uint32_t v = 0; v < vSize; v++)
+ {
+ result.Set(1.0f - vector.Get());
+ ++vector;
+ ++result;
+ }
+ vector -= vSize;
+ result -= vSize;
+}
+
+void VectorVectorCwiseProduct(armnn::Decoder<float>& vector1,
+ armnn::Decoder<float>& vector2,
+ uint32_t vSize,
+ armnn::Encoder<float>& outResult)
+{
+ for (uint32_t v = 0; v < vSize; v++)
+ {
+ outResult.Set(vector1.Get() * vector2.Get());
+ ++outResult;
+ ++vector1;
+ ++vector2;
+ }
+ outResult -= vSize;
+ vector1 -= vSize;
+ vector2 -= vSize;
+}
+
+void VectorVectorCwiseProductAccumulate(armnn::Decoder<float>& vector1,
+ armnn::Decoder<float>& vector2,
+ uint32_t vSize,
+ armnn::Encoder<float>& outResult)
+{
+ for (uint32_t v = 0; v < vSize; v++)
+ {
+ outResult.Set(outResult.Get() + vector1.Get() * vector2.Get());
+ ++outResult;
+ ++vector1;
+ ++vector2;
+ }
+ outResult -= vSize;
+ vector1 -= vSize;
+ vector2 -= vSize;
+}
+
+float Clip(float f,
+ float absLimit)
+{
+ float result = (absLimit < f) ? absLimit : f;
+ result = (-absLimit > result) ? -absLimit : result;
+ return result;
+}
+
+void ClipVector(armnn::Decoder<float>& vector,
+ uint32_t vSize,
+ float absLimit,
+ armnn::Encoder<float>& outResult)
+{
+ for (uint32_t v = 0; v < vSize; v++)
+ {
+ outResult.Set(Clip(vector.Get(), absLimit));
+ ++vector;
+ ++outResult;
+ }
+ vector -= vSize;
+ outResult -= vSize;
+}
+
+void CopyVector(armnn::Decoder<float>& vector,
+ uint32_t vSize,
+ armnn::Encoder<float>& outResult)
+{
+ for (uint32_t v = 0; v < vSize; v++)
+ {
+ outResult.Set(vector.Get());
+ ++outResult;
+ ++vector;
+ }
+ outResult -= vSize;
+ vector -= vSize;
+}
+
+void SetActivationParameters(uint32_t activation,
+ armnn::ActivationFunction& outArmnnActivation,
+ float& outA,
+ float& outB)
+{
+ switch (activation)
+ {
+ case 0: // None
+ outA = 0;
+ outB = 0;
+ return;
+
+ case 1: // Relu
+ outArmnnActivation = armnn::ActivationFunction::ReLu;
+ outA = 0;
+ outB = 0;
+ return;
+
+ case 3: // Relu6
+ outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
+ outA = 6;
+ outB = 0;
+ return;
+
+ case 4: // Tanh
+ outArmnnActivation = armnn::ActivationFunction::TanH;
+ outA = 1;
+ outB = 1;
+ return;
+
+ case 6: // Sigmoid
+ outArmnnActivation = armnn::ActivationFunction::Sigmoid;
+ outA = 0;
+ outB = 0;
+ return;
+
+ default:
+ throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
+ }
+}
+
+std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
+{
+ if (!ptr)
+ {
+ return nullptr;
+ }
+
+ return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
+}
+
+} // anonymous namespace
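Every helper above restores its decoders and encoders to their starting positions before returning, so the LSTM workload can chain calls on the same iterators. A hedged usage sketch (bias, weights, input and gate are assumed to be decoders/encoders over tensors of matching shapes; they are not defined in this patch):

    // gate = bias broadcast over the batch, then gate += weights * input, per batch.
    VectorBatchVectorAssign(bias, nCell, nBatch, gate);
    MatrixBatchVectorMultiplyAccumulate(weights, nCell, nInput, input, nBatch, gate);
    // Both calls leave 'gate' rewound, so it can be passed straight to Activation() next.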
diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.cpp b/src/backends/reference/workloads/RefLstmFloat32Workload.cpp
deleted file mode 100644
index c697b66658..0000000000
--- a/src/backends/reference/workloads/RefLstmFloat32Workload.cpp
+++ /dev/null
@@ -1,379 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefLstmFloat32Workload.hpp"
-#include "RefWorkloadUtils.hpp"
-#include "Activation.hpp"
-
-namespace
-{
-
-// Helper functions ported from the Android code base
-// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
-
-void MatrixBatchVectorMultiplyAccumulate(const float* matrix,
- uint32_t mRows,
- uint32_t mCols,
- const float* vector,
- uint32_t nBatch,
- float* outResult,
- int resultStride = 1)
-{
- float* resultInBatch = outResult;
- for (uint32_t b = 0; b < nBatch; b++)
- {
- const float* matrixPtr = matrix;
- for (uint32_t r = 0; r < mRows; r++)
- {
- const float* vectorInBatch = vector + b * mCols;
- for (uint32_t c = 0; c < mCols; c++)
- {
- *resultInBatch += *matrixPtr++ * *vectorInBatch++;
- }
- resultInBatch += resultStride;
- }
- }
-}
-
-void VectorBatchVectorAssign(const float* vector,
- uint32_t vSize,
- uint32_t nBatch,
- float* outBatchVector)
-{
- for (uint32_t b = 0; b < nBatch; b++)
- {
- memcpy(outBatchVector + b * vSize, vector, vSize * sizeof(float));
- }
-}
-
-void VectorBatchVectorCwiseProductAccumulate(const float* vector,
- uint32_t vSize,
- const float* batchVector,
- uint32_t nBatch,
- float* outResult)
-{
- for (uint32_t b = 0; b < nBatch; b++)
- {
- for (uint32_t v = 0; v < vSize; v++)
- {
- *outResult++ += vector[v] * *batchVector++;
- }
- }
-}
-
-void Sub1Vector(const float* vector,
- uint32_t vSize,
- float* result)
-{
- for (uint32_t v = 0; v < vSize; v++)
- {
- *result++ = 1.0f - *vector++;
- }
-}
-
-void VectorVectorCwiseProduct(const float* vector1,
- const float* vector2,
- uint32_t vSize,
- float* outResult)
-{
- for (uint32_t v = 0; v < vSize; v++)
- {
- *outResult++ = *vector1++ * *vector2++;
- }
-}
-
-void VectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2,
- uint32_t vSize,
- float* outResult)
-{
- for (uint32_t v = 0; v < vSize; v++)
- {
- *outResult++ += *vector1++ * *vector2++;
- }
-}
-
-float Clip(float f,
- float absLimit)
-{
- float result = (absLimit < f) ? absLimit : f;
- result = (-absLimit > result) ? -absLimit : result;
- return result;
-}
-
-void ClipVector(const float* vector,
- uint32_t vSize,
- float absLimit,
- float* outResult)
-{
- for (uint32_t v = 0; v < vSize; v++)
- {
- *outResult++ = Clip(*vector++, absLimit);
- }
-}
-
-void CopyVector(const float* vector,
- uint32_t vSize,
- float* outResult)
-{
- memcpy(outResult, vector, vSize * sizeof(float));
-}
-
-void SetActivationParameters(uint32_t activation,
- armnn::ActivationFunction& outArmnnActivation,
- float& outA,
- float& outB)
-{
- switch (activation)
- {
- case 0: // None
- outA = 0;
- outB = 0;
- return;
-
- case 1: // Relu
- outArmnnActivation = armnn::ActivationFunction::ReLu;
- outA = 0;
- outB = 0;
- return;
-
- case 3: // Relu6
- outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
- outA = 6;
- outB = 0;
- return;
-
- case 4: // Tanh
- outArmnnActivation = armnn::ActivationFunction::TanH;
- outA = 1;
- outB = 1;
- return;
-
- case 6: // Sigmoid
- outArmnnActivation = armnn::ActivationFunction::Sigmoid;
- outA = 0;
- outB = 0;
- return;
-
- default:
- throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
- }
-}
-
-std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
-{
- if (!ptr)
- {
- return nullptr;
- }
-
- return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
-}
-
-} // anonymous namespace
-
-namespace armnn
-{
-
-RefLstmFloat32Workload::RefLstmFloat32Workload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
- : Float32Workload<LstmQueueDescriptor>(descriptor, info)
- , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
- , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
- , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
- , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
- , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
- , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
- , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
- , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
- , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
- , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
- , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
- , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
- , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
- , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
- , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
- , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
- , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
-{}
-
-void RefLstmFloat32Workload::Execute() const
-{
- // This is a porting of the LSTM::Eval() method in the Android code base
- // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
-
- const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
- const TensorShape& inputShape = inputInfo.GetShape();
-
- float* scratchBuffer = GetOutputTensorDataFloat(0, m_Data);
- float* outputStateOut = GetOutputTensorDataFloat(1, m_Data);
- float* cellStateOut = GetOutputTensorDataFloat(2, m_Data);
- float* output = GetOutputTensorDataFloat(3, m_Data);
-
- const float* inputData = GetInputTensorDataFloat(0, m_Data);
- const float* outputStateIn = GetInputTensorDataFloat(1, m_Data);
- const float* cellStateIn = GetInputTensorDataFloat(2, m_Data);
-
- const uint32_t nBatch = inputShape[0];
- const uint32_t nInput = inputShape[1];
-
- const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
- const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
-
- const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
- const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
-
- // Index the scratch buffers pointers to the global scratch buffer.
- float* inputGateScratch = nullptr;
- float* cellScratch = nullptr;
- float* forgetGateScratch = nullptr;
- float* outputGateScratch = nullptr;
-
- if (useCifg)
- {
- cellScratch = scratchBuffer + 0 * nCell * nBatch;
- forgetGateScratch = scratchBuffer + 1 * nCell * nBatch;
- outputGateScratch = scratchBuffer + 2 * nCell * nBatch;
- }
- else
- {
- inputGateScratch = scratchBuffer + 0 * nCell * nBatch;
- cellScratch = scratchBuffer + 1 * nCell * nBatch;
- forgetGateScratch = scratchBuffer + 2 * nCell * nBatch;
- outputGateScratch = scratchBuffer + 3 * nCell * nBatch;
- }
-
- // Initialize scratch buffers with bias.
- if (!useCifg)
- {
- VectorBatchVectorAssign(m_InputGateBiasTensor->GetTensor<float>(),
- nCell, nBatch, inputGateScratch);
- }
- VectorBatchVectorAssign(m_ForgetGateBiasTensor->GetTensor<float>(),
- nCell, nBatch, forgetGateScratch);
- VectorBatchVectorAssign(m_CellBiasTensor->GetTensor<float>(),
- nCell, nBatch, cellScratch);
- VectorBatchVectorAssign(m_OutputGateBiasTensor->GetTensor<float>(),
- nCell, nBatch, outputGateScratch);
-
- // For each batch and cell: compute input_weight * input.
- if (!useCifg)
- {
- MatrixBatchVectorMultiplyAccumulate(m_InputToInputWeightsTensor->GetTensor<float>(),
- nCell, nInput, inputData, nBatch, inputGateScratch);
- }
- MatrixBatchVectorMultiplyAccumulate(m_InputToForgetWeightsTensor->GetTensor<float>(),
- nCell, nInput, inputData, nBatch, forgetGateScratch);
- MatrixBatchVectorMultiplyAccumulate(m_InputToCellWeightsTensor->GetTensor<float>(),
- nCell, nInput, inputData, nBatch, cellScratch);
- MatrixBatchVectorMultiplyAccumulate(m_InputToOutputWeightsTensor->GetTensor<float>(),
- nCell, nInput, inputData, nBatch, outputGateScratch);
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!useCifg)
- {
- MatrixBatchVectorMultiplyAccumulate(m_RecurrentToInputWeightsTensor->GetTensor<float>(),
- nCell, nOutput, outputStateIn, nBatch, inputGateScratch);
- }
- MatrixBatchVectorMultiplyAccumulate(m_RecurrentToForgetWeightsTensor->GetTensor<float>(),
- nCell, nOutput, outputStateIn, nBatch, forgetGateScratch);
- MatrixBatchVectorMultiplyAccumulate(m_RecurrentToCellWeightsTensor->GetTensor<float>(),
- nCell, nOutput, outputStateIn, nBatch, cellScratch);
- MatrixBatchVectorMultiplyAccumulate(m_RecurrentToOutputWeightsTensor->GetTensor<float>(),
- nCell, nOutput, outputStateIn, nBatch, outputGateScratch);
-
- // For each batch and cell: update input gate.
- if (!useCifg)
- {
- if (usePeephole)
- {
- VectorBatchVectorCwiseProductAccumulate(m_CellToInputWeightsTensor->GetTensor<float>(),
- nCell, cellStateIn, nBatch, inputGateScratch);
- }
- Activation(inputGateScratch, inputGateScratch,
- TensorInfo({nCell, nBatch}, DataType::Float32),
- ActivationFunction::Sigmoid, 0, 0);
- }
-
- // For each batch and cell: update forget gate.
- if (usePeephole)
- {
- VectorBatchVectorCwiseProductAccumulate(m_CellToForgetWeightsTensor->GetTensor<float>(), nCell,
- cellStateIn, nBatch, forgetGateScratch);
- }
- Activation(forgetGateScratch, forgetGateScratch,
- TensorInfo({nCell, nBatch}, DataType::Float32),
- ActivationFunction::Sigmoid, 0, 0);
-
- // For each batch and cell: update the cell.
- VectorVectorCwiseProduct(forgetGateScratch, cellStateIn, nBatch * nCell, cellStateOut);
-
- ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
- float a = 0;
- float b = 0;
- SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);
-
- if (m_Data.m_Parameters.m_ActivationFunc > 0)
- {
- Activation(cellScratch, cellScratch,
- TensorInfo({nCell, nBatch}, DataType::Float32),
- armnnActivationFunc, a, b);
- }
- if (useCifg)
- {
- Sub1Vector(forgetGateScratch, nBatch * nCell, forgetGateScratch);
- VectorVectorCwiseProductAccumulate(cellScratch, forgetGateScratch, nBatch * nCell, cellStateOut);
- }
- else
- {
- VectorVectorCwiseProductAccumulate(cellScratch, inputGateScratch, nBatch * nCell, cellStateOut);
- }
- if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
- {
- ClipVector(cellStateOut, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, cellStateOut);
- }
-
- // For each batch and cell: update the output gate.
- if (usePeephole)
- {
- VectorBatchVectorCwiseProductAccumulate(m_CellToOutputWeightsTensor->GetTensor<float>(),
- nCell, cellStateOut, nBatch, outputGateScratch);
- }
- Activation(outputGateScratch, outputGateScratch,
- TensorInfo({nCell, nBatch}, DataType::Float32),
- ActivationFunction::Sigmoid, 0, 0);
-
- if (m_Data.m_Parameters.m_ActivationFunc > 0)
- {
- Activation(cellStateOut, cellScratch,
- TensorInfo({nCell, nBatch}, DataType::Float32),
- armnnActivationFunc, a, b);
- }
- VectorVectorCwiseProduct(outputGateScratch, cellScratch, nBatch * nCell, outputGateScratch);
-
- // For each batch: update the projection and output_state.
- if (m_Data.m_Parameters.m_ProjectionEnabled)
- {
- if (m_ProjectionBiasTensor)
- {
- VectorBatchVectorAssign(m_ProjectionBiasTensor->GetTensor<float>(),
- nOutput, nBatch, output);
- }
- MatrixBatchVectorMultiplyAccumulate(m_ProjectionWeightsTensor->GetTensor<float>(),
- nOutput, nCell, outputGateScratch, nBatch, output);
-
- if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
- {
- ClipVector(output, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, output);
- }
- }
- else
- {
- CopyVector(outputGateScratch, nBatch * nOutput, output);
- }
-
- CopyVector(output, nBatch * nOutput, outputStateOut);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp
new file mode 100644
index 0000000000..f8ebc58f6e
--- /dev/null
+++ b/src/backends/reference/workloads/RefLstmWorkload.cpp
@@ -0,0 +1,307 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefLstmWorkload.hpp"
+#include "Activation.hpp"
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+#include "LstmUtils.hpp"
+#include "RefWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
+ : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
+ , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
+ , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
+ , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
+ , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
+ , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
+ , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
+ , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
+ , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
+ , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
+ , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
+ , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
+ , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
+ , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
+ , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
+ , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
+ , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
+ , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
+{}
+
+void RefLstmWorkload::Execute() const
+{
+ // This is a porting of the LSTM::Eval() method in the Android code base
+ // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
+
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inputShape = inputInfo.GetShape();
+ const DataType& outputType = outputInfo.GetDataType();
+
+ std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map());
+ std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+ std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());
+
+ std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+ std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());
+
+ std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+ std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map());
+ std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map());
+
+ const uint32_t nBatch = inputShape[0];
+ const uint32_t nInput = inputShape[1];
+
+ const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
+ const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
+
+ const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
+ const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+ std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
+ MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Decoder<float>> cellScratchDecoder =
+ MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
+ MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+ std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
+ MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+ if (useCifg)
+ {
+ *cellScratch += (0 * nCell * nBatch);
+ *forgetGateScratch += (1 * nCell * nBatch);
+ *outputGateScratch += (2 * nCell * nBatch);
+
+ *cellScratchDecoder += (0 * nCell * nBatch);
+ *forgetGateScratchDecoder += (1 * nCell * nBatch);
+ *outputGateScratchDecoder += (2 * nCell * nBatch);
+ }
+ else
+ {
+ *inputGateScratch += (0 * nCell * nBatch);
+ *cellScratch += (1 * nCell * nBatch);
+ *forgetGateScratch += (2 * nCell * nBatch);
+ *outputGateScratch += (3 * nCell * nBatch);
+
+ *inputGateScratchDecoder += (0 * nCell * nBatch);
+ *cellScratchDecoder += (1 * nCell * nBatch);
+ *forgetGateScratchDecoder += (2 * nCell * nBatch);
+ *outputGateScratchDecoder += (3 * nCell * nBatch);
+ }
+
+ std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
+ std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
+ m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
+ m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
+ m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());
+
+ std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
+ std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
+ m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
+ m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
+ m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());
+
+ std::unique_ptr<Decoder<float>> inputGateBiasTensor;
+ std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
+ m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
+ m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<void>());
+ std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
+ m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<void>());
+
+ std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
+ std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
+ std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
+
+ std::unique_ptr<Decoder<float>> projectionWeightsTensor;
+ std::unique_ptr<Decoder<float>> projectionBiasTensor;
+
+ if (!useCifg)
+ {
+ inputToInputWeightsTensor = MakeDecoder<float>(
+ m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
+ inputGateBiasTensor = MakeDecoder<float>(
+ m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<void>());
+ recurrentToInputWeightsTensor = MakeDecoder<float>(
+ m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
+ }
+
+ if (usePeephole)
+ {
+ cellToForgetWeightsTensor = MakeDecoder<float>(
+ m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
+ cellToOutputWeightsTensor = MakeDecoder<float>(
+ m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
+ }
+
+ if (!useCifg && usePeephole)
+ {
+ cellToInputWeightsTensor = MakeDecoder<float>(
+ m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
+ }
+
+ if (m_Data.m_Parameters.m_ProjectionEnabled)
+ {
+ projectionWeightsTensor = MakeDecoder<float>(
+ m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
+ if (m_ProjectionBiasTensor)
+ {
+ projectionBiasTensor = MakeDecoder<float>(
+ m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
+ }
+ }
+
+ // Initialize scratch buffers with bias.
+ if (!useCifg)
+ {
+ VectorBatchVectorAssign(*inputGateBiasTensor,
+ nCell, nBatch, *inputGateScratch);
+ }
+ VectorBatchVectorAssign(*forgetGateBiasTensor,
+ nCell, nBatch, *forgetGateScratch);
+ VectorBatchVectorAssign(*cellBiasTensor,
+ nCell, nBatch, *cellScratch);
+ VectorBatchVectorAssign(*outputGateBiasTensor,
+ nCell, nBatch, *outputGateScratch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!useCifg)
+ {
+ MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *inputGateScratch);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *forgetGateScratch);
+ MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *cellScratch);
+ MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
+ nCell, nInput, *inputData, nBatch, *outputGateScratch);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!useCifg)
+ {
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
+ MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
+ nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
+
+ // For each batch and cell: update input gate.
+ if (!useCifg)
+ {
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
+ nCell, *cellStateIn, nBatch, *inputGateScratch);
+ }
+ Activation(*inputGateScratchDecoder, *inputGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
+ *cellStateIn, nBatch, *forgetGateScratch);
+ }
+ Activation(*forgetGateScratchDecoder, *forgetGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+
+ // For each batch and cell: update the cell.
+ VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
+
+ ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
+ float a = 0;
+ float b = 0;
+ SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);
+
+ if (m_Data.m_Parameters.m_ActivationFunc > 0)
+ {
+ Activation(*cellScratchDecoder, *cellScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ armnnActivationFunc, a, b);
+ }
+ if (useCifg)
+ {
+ Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
+ VectorVectorCwiseProductAccumulate(
+ *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
+ }
+ else
+ {
+ VectorVectorCwiseProductAccumulate(
+ *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
+ }
+ if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
+ {
+ ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (usePeephole)
+ {
+ VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
+ nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
+ }
+ Activation(*outputGateScratchDecoder, *outputGateScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ ActivationFunction::Sigmoid, 0, 0);
+
+ if (m_Data.m_Parameters.m_ActivationFunc > 0)
+ {
+ Activation(*cellStateOutDecoder, *cellScratch,
+ TensorInfo({nCell, nBatch}, outputType),
+ armnnActivationFunc, a, b);
+ }
+
+ VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
+
+ // For each batch: update the projection and output_state.
+ if (m_Data.m_Parameters.m_ProjectionEnabled)
+ {
+ if (m_ProjectionBiasTensor)
+ {
+ VectorBatchVectorAssign(*projectionBiasTensor,
+ nOutput, nBatch, *output);
+ }
+ MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
+ nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
+
+ if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
+ {
+ ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output);
+ }
+ }
+ else
+ {
+ CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
+ }
+
+ CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
+}
+
+} //namespace armnn
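The workload builds all of its encoders and decoders over mapped tensor handles and then offsets the scratch iterators into the single scratch output with +=, instead of indexing with raw float pointers. A minimal sketch of that mapping pattern, assuming the output index and offset shown here are purely illustrative:

    const armnn::TensorInfo& scratchInfo = GetTensorInfo(m_Data.m_Outputs[0]);
    std::unique_ptr<armnn::Encoder<float>> scratch =
        armnn::MakeEncoder<float>(scratchInfo, m_Data.m_Outputs[0]->Map());
    *scratch += 2 * nCell * nBatch;   // position the encoder at the third nCell*nBatch block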
diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp b/src/backends/reference/workloads/RefLstmWorkload.hpp
index a2dead8b9c..38e3fb956c 100644
--- a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefLstmWorkload.hpp
@@ -13,10 +13,10 @@
namespace armnn
{
-class RefLstmFloat32Workload : public Float32Workload<LstmQueueDescriptor>
+class RefLstmWorkload : public BaseWorkload<LstmQueueDescriptor>
{
public:
- explicit RefLstmFloat32Workload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
+ explicit RefLstmWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
virtual void Execute() const override;
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 7871a1b806..8ffd3485ae 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -51,7 +51,7 @@
#include "Pooling2d.hpp"
#include "RefFakeQuantizationFloat32Workload.hpp"
#include "RefPermuteWorkload.hpp"
-#include "RefLstmFloat32Workload.hpp"
+#include "RefLstmWorkload.hpp"
#include "RefConvertFp16ToFp32Workload.hpp"
#include "RefConvertFp32ToFp16Workload.hpp"
#include "RefMeanUint8Workload.hpp"