aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2022-01-17 21:19:52 +0000
committerTeresa Charlin <teresa.charlinreyes@arm.com>2022-01-18 21:32:27 +0000
commit788e2a6c917abe7c5187d9e5c349683d456080e5 (patch)
tree1c88b36828a3ebb6ef04c3e8c0bc92bdd330e9cf
parentadeebaa73205bd981ea7e8c8f135f01355cba841 (diff)
downloadarmnn-788e2a6c917abe7c5187d9e5c349683d456080e5.tar.gz
IVGCVSW-6682 Add ReplaceTensorHandle functions to IWorkload and BaseWorkload
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I9f80b9f45206db920568e28e363fcb60f5c0819a
-rw-r--r--include/armnn/backends/IWorkload.hpp6
-rw-r--r--include/armnn/backends/Workload.hpp12
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp47
-rw-r--r--src/backends/reference/workloads/RefActivationWorkload.hpp4
4 files changed, 65 insertions, 4 deletions
diff --git a/include/armnn/backends/IWorkload.hpp b/include/armnn/backends/IWorkload.hpp
index a4827ebcdf..d63e0acc72 100644
--- a/include/armnn/backends/IWorkload.hpp
+++ b/include/armnn/backends/IWorkload.hpp
@@ -31,6 +31,12 @@ public:
virtual profiling::ProfilingGuid GetGuid() const = 0;
+ // Replace input tensor handle with the given TensorHandle
+ virtual void ReplaceInputTensorHandle(ITensorHandle* /*input*/, unsigned int /*slot*/) = 0;
+
+ // Replace output tensor handle with the given TensorHandle
+ virtual void ReplaceOutputTensorHandle(ITensorHandle* /*output*/, unsigned int /*slot*/) = 0;
+
virtual void RegisterDebugCallback(const DebugCallbackFunction& /*func*/) {}
};
diff --git a/include/armnn/backends/Workload.hpp b/include/armnn/backends/Workload.hpp
index 7c1bda50bc..07e1abb392 100644
--- a/include/armnn/backends/Workload.hpp
+++ b/include/armnn/backends/Workload.hpp
@@ -54,6 +54,18 @@ public:
profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
+ {
+ m_Data.m_Inputs[slot] = tensorHandle;
+ }
+
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
+ {
+ m_Data.m_Outputs[slot] = tensorHandle;
+ }
+
protected:
QueueDescriptor m_Data;
const profiling::ProfilingGuid m_Guid;
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index e865f25f49..6dbbd556a9 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -7,6 +7,7 @@
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <reference/RefTensorHandle.hpp>
+#include <reference/RefTensorHandleFactory.hpp>
#include <reference/RefWorkloadFactory.hpp>
#include <reference/workloads/RefWorkloads.hpp>
@@ -46,7 +47,6 @@ armnn::RefWorkloadFactory GetFactory()
return RefWorkloadFactory(memoryManager);
}
-
}
TEST_SUITE("CreateWorkloadRef")
@@ -1271,4 +1271,47 @@ TEST_CASE("CreateQLstmWorkload")
RefCreateQLstmWorkloadTest<RefQLstmWorkload>();
}
+template <armnn::DataType DataType>
+static void RefCreateActivationWorkloadReplaceFunctionsTest()
+{
+ Graph graph;
+ RefWorkloadFactory factory = GetFactory();
+ // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType)
+ auto workloadPtr = CreateActivationWorkloadTest<RefActivationWorkload, DataType>(factory, graph);
+
+ // new input and output tensor handlers are created and then replace in the workload
+ shared_ptr<RefMemoryManager> memoryManager = make_shared<RefMemoryManager>();
+ const RefTensorHandleFactory tensorHandleFactory(memoryManager);
+ TensorInfo inputInfo({2 , 2}, DataType::Float16);
+ TensorInfo outputInfo({2 , 2}, DataType::Float16);
+ unique_ptr<ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
+ unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
+ unsigned int slot = 0;
+ workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot);
+ workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot);
+
+ // Check if the tensor handlers inside the workload are the same as ones we replace with
+ auto queueDescriptor = workloadPtr->GetData();
+ auto inputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
+ auto outputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
+ CHECK((inputHandleTest->GetTensorInfo() == inputInfo));
+ CHECK((outputHandleTest->GetTensorInfo() == outputInfo));
+ CHECK(inputHandle.get() == inputHandleTest);
+ CHECK(outputHandle.get() == outputHandleTest);
+ inputHandle->Allocate();
+ CHECK(inputHandle->Map() == inputHandleTest->Map());
+ outputHandle->Allocate();
+ CHECK(outputHandle->Map() == outputHandleTest->Map());
+}
+
+TEST_CASE("ReplaceFunctionsfromFloat32toFloat16ActivationWorkload")
+{
+ RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>();
+}
+
+TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload")
+{
+ RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>();
+}
+
}
diff --git a/src/backends/reference/workloads/RefActivationWorkload.hpp b/src/backends/reference/workloads/RefActivationWorkload.hpp
index e3bd8706a4..9814ac172b 100644
--- a/src/backends/reference/workloads/RefActivationWorkload.hpp
+++ b/src/backends/reference/workloads/RefActivationWorkload.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -16,7 +16,7 @@ class RefActivationWorkload : public BaseWorkload<ActivationQueueDescriptor>
public:
using BaseWorkload<ActivationQueueDescriptor>::BaseWorkload;
void Execute() const override;
- void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
+ void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;