aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp')
-rw-r--r--src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp18
1 files changed, 10 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
index c04e97bb0f..a69f7270b4 100644
--- a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
@@ -314,7 +314,7 @@ LstmNoCifgNoPeepholeNoProjectionTestImpl(
data.m_Parameters.m_PeepholeEnabled = false;
data.m_Parameters.m_ProjectionEnabled = false;
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
inputHandle->Allocate();
outputStateInHandle->Allocate();
cellStateInHandle->Allocate();
@@ -987,7 +987,7 @@ LstmLayerNoCifgWithPeepholeWithProjectionTestImpl(armnn::IWorkloadFactory& workl
data.m_Parameters.m_PeepholeEnabled = true;
data.m_Parameters.m_ProjectionEnabled = true;
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
inputHandle->Allocate();
outputStateInHandle->Allocate();
cellStateInHandle->Allocate();
@@ -1211,7 +1211,7 @@ LayerTestResult<T, 2> LstmLayerWithCifgWithPeepholeNoProjectionTestImpl(
AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
inputHandle->Allocate();
outputStateInHandle->Allocate();
@@ -1464,7 +1464,7 @@ LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl(armnn::IWorkloadF
data.m_Parameters.m_LayerNormEnabled = true;
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
inputHandle->Allocate();
outputStateInHandle->Allocate();
cellStateInHandle->Allocate();
@@ -1653,7 +1653,9 @@ LayerTestResult<uint8_t, 2> QuantizedLstmTestImpl(
data.m_OutputGateBias = &outputGateBiasTensor;
// Create workload and allocate tensor handles
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQuantizedLstm(data, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QuantizedLstm,
+ data,
+ info);
inputHandle->Allocate();
outputStateInHandle->Allocate();
cellStateInHandle->Allocate();
@@ -1890,7 +1892,7 @@ LayerTestResult<int8_t, 2> QLstmTestImpl(
data.m_Parameters.m_ProjectionClip = projectionClip;
// Create workload and allocate tensor handles
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
inputHandle->Allocate();
outputStateInHandle->Allocate();
cellStateInHandle->Allocate();
@@ -2155,7 +2157,7 @@ LayerTestResult<int8_t, 2> QLstmTestImpl1(
data.m_Parameters.m_ProjectionClip = projectionClip;
// Create workload and allocate tensor handles
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
inputHandle->Allocate();
outputStateInHandle->Allocate();
cellStateInHandle->Allocate();
@@ -2406,7 +2408,7 @@ LayerTestResult<int8_t, 2> QLstmTestImpl2(
data.m_Parameters.m_ProjectionClip = projectionClip;
// Create workload and allocate tensor handles
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
inputHandle->Allocate();
outputStateInHandle->Allocate();
cellStateInHandle->Allocate();