aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test/RefCreateWorkloadTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp')
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp47
1 files changed, 45 insertions, 2 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index e865f25f49..6dbbd556a9 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -7,6 +7,7 @@
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <reference/RefTensorHandle.hpp>
+#include <reference/RefTensorHandleFactory.hpp>
#include <reference/RefWorkloadFactory.hpp>
#include <reference/workloads/RefWorkloads.hpp>
@@ -46,7 +47,6 @@ armnn::RefWorkloadFactory GetFactory()
return RefWorkloadFactory(memoryManager);
}
-
}
TEST_SUITE("CreateWorkloadRef")
@@ -1271,4 +1271,47 @@ TEST_CASE("CreateQLstmWorkload")
RefCreateQLstmWorkloadTest<RefQLstmWorkload>();
}
+template <armnn::DataType DataType>
+static void RefCreateActivationWorkloadReplaceFunctionsTest()
+{
+ Graph graph;
+ RefWorkloadFactory factory = GetFactory();
+ // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType)
+ auto workloadPtr = CreateActivationWorkloadTest<RefActivationWorkload, DataType>(factory, graph);
+
+ // new input and output tensor handlers are created and then replace in the workload
+ shared_ptr<RefMemoryManager> memoryManager = make_shared<RefMemoryManager>();
+ const RefTensorHandleFactory tensorHandleFactory(memoryManager);
+ TensorInfo inputInfo({2 , 2}, DataType::Float16);
+ TensorInfo outputInfo({2 , 2}, DataType::Float16);
+ unique_ptr<ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
+ unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
+ unsigned int slot = 0;
+ workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot);
+ workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot);
+
+ // Check if the tensor handlers inside the workload are the same as ones we replace with
+ auto queueDescriptor = workloadPtr->GetData();
+ auto inputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
+ auto outputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
+ CHECK((inputHandleTest->GetTensorInfo() == inputInfo));
+ CHECK((outputHandleTest->GetTensorInfo() == outputInfo));
+ CHECK(inputHandle.get() == inputHandleTest);
+ CHECK(outputHandle.get() == outputHandleTest);
+ inputHandle->Allocate();
+ CHECK(inputHandle->Map() == inputHandleTest->Map());
+ outputHandle->Allocate();
+ CHECK(outputHandle->Map() == outputHandleTest->Map());
+}
+
+TEST_CASE("ReplaceFunctionsfromFloat32toFloat16ActivationWorkload")
+{
+ RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>();
+}
+
+TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload")
+{
+ RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>();
+}
+
}