diff options
Diffstat (limited to 'src/backends/neon/test/NeonCreateWorkloadTests.cpp')
-rw-r--r-- | src/backends/neon/test/NeonCreateWorkloadTests.cpp | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp index c1563fe046..66718cc481 100644 --- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp +++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp @@ -1059,4 +1059,38 @@ TEST_CASE("CreateQLstmWorkloadTest") NeonCreateQLstmWorkloadTest<NeonQLstmWorkload>(); } +template <armnn::DataType DataType> +static void NeonCreateActivationWorkloadReplaceFunctionsTest() +{ + shared_ptr<NeonMemoryManager> memoryManager = make_shared<NeonMemoryManager>(); + + Graph graph; + NeonWorkloadFactory factory = NeonWorkloadFactoryHelper::GetFactory(memoryManager); + // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType) + auto workloadPtr = CreateActivationWorkloadTest<NeonActivationWorkload, DataType>(factory, graph); + + // new input and output tensor handlers are created and then replace in the workload + const NeonTensorHandleFactory tensorHandleFactory(memoryManager); + TensorInfo inputInfo({2 , 2}, DataType::Float16); + TensorInfo outputInfo({2 , 2}, DataType::Float16); + unique_ptr<ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); + inputHandle->Allocate(); + unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo); + outputHandle->Allocate(); + + unsigned int slot = 0; + CHECK_THROWS_AS(workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot), UnimplementedException); + CHECK_THROWS_AS(workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot), UnimplementedException); +} + +TEST_CASE("NeonReplaceFunctionsfromFloat32toFloat16ActivationWorkload") +{ + NeonCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>(); +} + +TEST_CASE("NeonReplaceFunctionsfromUint8toFloat16ActivationWorkload") +{ + NeonCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>(); +} + } |