From 788e2a6c917abe7c5187d9e5c349683d456080e5 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Mon, 17 Jan 2022 21:19:52 +0000 Subject: IVGCVSW-6682 Add ReplaceTensorHandle functions to IWorkload and BaseWorkload Signed-off-by: Teresa Charlin Change-Id: I9f80b9f45206db920568e28e363fcb60f5c0819a --- include/armnn/backends/IWorkload.hpp | 6 +++ include/armnn/backends/Workload.hpp | 12 ++++++ .../reference/test/RefCreateWorkloadTests.cpp | 47 +++++++++++++++++++++- .../reference/workloads/RefActivationWorkload.hpp | 4 +- 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/include/armnn/backends/IWorkload.hpp b/include/armnn/backends/IWorkload.hpp index a4827ebcdf..d63e0acc72 100644 --- a/include/armnn/backends/IWorkload.hpp +++ b/include/armnn/backends/IWorkload.hpp @@ -31,6 +31,12 @@ public: virtual profiling::ProfilingGuid GetGuid() const = 0; + // Replace input tensor handle with the given TensorHandle + virtual void ReplaceInputTensorHandle(ITensorHandle* /*input*/, unsigned int /*slot*/) = 0; + + // Replace output tensor handle with the given TensorHandle + virtual void ReplaceOutputTensorHandle(ITensorHandle* /*output*/, unsigned int /*slot*/) = 0; + virtual void RegisterDebugCallback(const DebugCallbackFunction& /*func*/) {} }; diff --git a/include/armnn/backends/Workload.hpp b/include/armnn/backends/Workload.hpp index 7c1bda50bc..07e1abb392 100644 --- a/include/armnn/backends/Workload.hpp +++ b/include/armnn/backends/Workload.hpp @@ -54,6 +54,18 @@ public: profiling::ProfilingGuid GetGuid() const final { return m_Guid; } + // Replace input tensor handle with the given TensorHandle + void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override + { + m_Data.m_Inputs[slot] = tensorHandle; + } + + // Replace output tensor handle with the given TensorHandle + void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override + { + m_Data.m_Outputs[slot] = tensorHandle; + } + protected: QueueDescriptor m_Data; const profiling::ProfilingGuid m_Guid; diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index e865f25f49..6dbbd556a9 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -46,7 +47,6 @@ armnn::RefWorkloadFactory GetFactory() return RefWorkloadFactory(memoryManager); } - } TEST_SUITE("CreateWorkloadRef") @@ -1271,4 +1271,47 @@ TEST_CASE("CreateQLstmWorkload") RefCreateQLstmWorkloadTest(); } +template +static void RefCreateActivationWorkloadReplaceFunctionsTest() +{ + Graph graph; + RefWorkloadFactory factory = GetFactory(); + // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType) + auto workloadPtr = CreateActivationWorkloadTest(factory, graph); + + // new input and output tensor handlers are created and then replace in the workload + shared_ptr memoryManager = make_shared(); + const RefTensorHandleFactory tensorHandleFactory(memoryManager); + TensorInfo inputInfo({2 , 2}, DataType::Float16); + TensorInfo outputInfo({2 , 2}, DataType::Float16); + unique_ptr inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); + unique_ptr outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo); + unsigned int slot = 0; + workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot); + workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot); + + // Check if the tensor handlers inside the workload are the same as ones we replace with + auto queueDescriptor = workloadPtr->GetData(); + auto inputHandleTest = PolymorphicDowncast(queueDescriptor.m_Inputs[0]); + auto outputHandleTest = PolymorphicDowncast(queueDescriptor.m_Outputs[0]); + CHECK((inputHandleTest->GetTensorInfo() == inputInfo)); + CHECK((outputHandleTest->GetTensorInfo() == outputInfo)); + CHECK(inputHandle.get() == inputHandleTest); + CHECK(outputHandle.get() == outputHandleTest); + inputHandle->Allocate(); + CHECK(inputHandle->Map() == inputHandleTest->Map()); + outputHandle->Allocate(); + CHECK(outputHandle->Map() == outputHandleTest->Map()); +} + +TEST_CASE("ReplaceFunctionsfromFloat32toFloat16ActivationWorkload") +{ + RefCreateActivationWorkloadReplaceFunctionsTest(); +} + +TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload") +{ + RefCreateActivationWorkloadReplaceFunctionsTest(); +} + } diff --git a/src/backends/reference/workloads/RefActivationWorkload.hpp b/src/backends/reference/workloads/RefActivationWorkload.hpp index e3bd8706a4..9814ac172b 100644 --- a/src/backends/reference/workloads/RefActivationWorkload.hpp +++ b/src/backends/reference/workloads/RefActivationWorkload.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -16,7 +16,7 @@ class RefActivationWorkload : public BaseWorkload public: using BaseWorkload::BaseWorkload; void Execute() const override; - void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: void Execute(std::vector inputs, std::vector outputs) const; -- cgit v1.2.1