diff options
Diffstat (limited to 'src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp')
-rw-r--r-- | src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp index c04e97bb0f..a69f7270b4 100644 --- a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp +++ b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp @@ -314,7 +314,7 @@ LstmNoCifgNoPeepholeNoProjectionTestImpl( data.m_Parameters.m_PeepholeEnabled = false; data.m_Parameters.m_ProjectionEnabled = false; - std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info); + std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info); inputHandle->Allocate(); outputStateInHandle->Allocate(); cellStateInHandle->Allocate(); @@ -987,7 +987,7 @@ LstmLayerNoCifgWithPeepholeWithProjectionTestImpl(armnn::IWorkloadFactory& workl data.m_Parameters.m_PeepholeEnabled = true; data.m_Parameters.m_ProjectionEnabled = true; - std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info); + std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info); inputHandle->Allocate(); outputStateInHandle->Allocate(); cellStateInHandle->Allocate(); @@ -1211,7 +1211,7 @@ LayerTestResult<T, 2> LstmLayerWithCifgWithPeepholeNoProjectionTestImpl( AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get()); AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); - std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info); + std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info); inputHandle->Allocate(); outputStateInHandle->Allocate(); @@ -1464,7 +1464,7 @@ LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl(armnn::IWorkloadF data.m_Parameters.m_LayerNormEnabled = true; - std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info); + std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info); inputHandle->Allocate(); outputStateInHandle->Allocate(); cellStateInHandle->Allocate(); @@ -1653,7 +1653,9 @@ LayerTestResult<uint8_t, 2> QuantizedLstmTestImpl( data.m_OutputGateBias = &outputGateBiasTensor; // Create workload and allocate tensor handles - std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQuantizedLstm(data, info); + std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QuantizedLstm, + data, + info); inputHandle->Allocate(); outputStateInHandle->Allocate(); cellStateInHandle->Allocate(); @@ -1890,7 +1892,7 @@ LayerTestResult<int8_t, 2> QLstmTestImpl( data.m_Parameters.m_ProjectionClip = projectionClip; // Create workload and allocate tensor handles - std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info); + std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info); inputHandle->Allocate(); outputStateInHandle->Allocate(); cellStateInHandle->Allocate(); @@ -2155,7 +2157,7 @@ LayerTestResult<int8_t, 2> QLstmTestImpl1( data.m_Parameters.m_ProjectionClip = projectionClip; // Create workload and allocate tensor handles - std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info); + std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info); inputHandle->Allocate(); outputStateInHandle->Allocate(); cellStateInHandle->Allocate(); @@ -2406,7 +2408,7 @@ LayerTestResult<int8_t, 2> QLstmTestImpl2( data.m_Parameters.m_ProjectionClip = projectionClip; // Create workload and allocate tensor handles - std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info); + std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info); inputHandle->Allocate(); outputStateInHandle->Allocate(); cellStateInHandle->Allocate(); |