diff options
Diffstat (limited to 'src/backends/cl/test/ClCreateWorkloadTests.cpp')
-rw-r--r-- | src/backends/cl/test/ClCreateWorkloadTests.cpp | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index 34914fca50..d8b2d4f786 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -1297,4 +1297,36 @@ TEST_CASE_FIXTURE(ClContextControlFixture, "CreateQuantizedLstmWorkload") ClCreateQuantizedLstmWorkloadTest<ClQuantizedLstmWorkload>(); } +template <armnn::DataType DataType> +static void ClCreateActivationWorkloadReplaceFunctionsTest() +{ + std::shared_ptr<ClMemoryManager> memoryManager = std::make_shared<ClMemoryManager>( + std::make_unique<arm_compute::CLBufferAllocator>()); + + Graph graph; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(memoryManager); + // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType) + auto workloadPtr = CreateActivationWorkloadTest<ClActivationWorkload, DataType>(factory, graph); + + // new input and output tensor handlers are created and then replace in the workload + const ClTensorHandleFactory tensorHandleFactory(memoryManager); + TensorInfo inputInfo({2 , 2}, DataType::Float16); + TensorInfo outputInfo({2 , 2}, DataType::Float16); + unique_ptr<ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo, true); + inputHandle->Manage(); + inputHandle->Allocate(); + unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo, true); + outputHandle->Manage(); + outputHandle->Allocate(); + + unsigned int slot = 0; + CHECK_THROWS_AS(workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot), UnimplementedException); + CHECK_THROWS_AS(workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot), UnimplementedException); +} + +TEST_CASE("ClReplaceFunctionsfromFloat32toFloat16ActivationWorkload") +{ + ClCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>(); +} + } |