diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-01-17 21:19:52 +0000 |
---|---|---|
committer | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-01-18 21:32:27 +0000 |
commit | 788e2a6c917abe7c5187d9e5c349683d456080e5 (patch) | |
tree | 1c88b36828a3ebb6ef04c3e8c0bc92bdd330e9cf /src | |
parent | adeebaa73205bd981ea7e8c8f135f01355cba841 (diff) | |
download | armnn-788e2a6c917abe7c5187d9e5c349683d456080e5.tar.gz |
IVGCVSW-6682 Add ReplaceTensorHandle functions to IWorkload and BaseWorkload
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I9f80b9f45206db920568e28e363fcb60f5c0819a
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 47 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefActivationWorkload.hpp | 4 |
2 files changed, 47 insertions, 4 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index e865f25f49..6dbbd556a9 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -7,6 +7,7 @@ #include <armnn/utility/PolymorphicDowncast.hpp> #include <reference/RefTensorHandle.hpp> +#include <reference/RefTensorHandleFactory.hpp> #include <reference/RefWorkloadFactory.hpp> #include <reference/workloads/RefWorkloads.hpp> @@ -46,7 +47,6 @@ armnn::RefWorkloadFactory GetFactory() return RefWorkloadFactory(memoryManager); } - } TEST_SUITE("CreateWorkloadRef") @@ -1271,4 +1271,47 @@ TEST_CASE("CreateQLstmWorkload") RefCreateQLstmWorkloadTest<RefQLstmWorkload>(); } +template <armnn::DataType DataType> +static void RefCreateActivationWorkloadReplaceFunctionsTest() +{ + Graph graph; + RefWorkloadFactory factory = GetFactory(); + // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType) + auto workloadPtr = CreateActivationWorkloadTest<RefActivationWorkload, DataType>(factory, graph); + + // new input and output tensor handlers are created and then replace in the workload + shared_ptr<RefMemoryManager> memoryManager = make_shared<RefMemoryManager>(); + const RefTensorHandleFactory tensorHandleFactory(memoryManager); + TensorInfo inputInfo({2 , 2}, DataType::Float16); + TensorInfo outputInfo({2 , 2}, DataType::Float16); + unique_ptr<ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); + unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo); + unsigned int slot = 0; + workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot); + workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot); + + // Check if the tensor handlers inside the workload are the same as ones we replace with + auto queueDescriptor = workloadPtr->GetData(); + auto inputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]); + auto outputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]); + CHECK((inputHandleTest->GetTensorInfo() == inputInfo)); + CHECK((outputHandleTest->GetTensorInfo() == outputInfo)); + CHECK(inputHandle.get() == inputHandleTest); + CHECK(outputHandle.get() == outputHandleTest); + inputHandle->Allocate(); + CHECK(inputHandle->Map() == inputHandleTest->Map()); + outputHandle->Allocate(); + CHECK(outputHandle->Map() == outputHandleTest->Map()); +} + +TEST_CASE("ReplaceFunctionsfromFloat32toFloat16ActivationWorkload") +{ + RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>(); +} + +TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload") +{ + RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>(); +} + } diff --git a/src/backends/reference/workloads/RefActivationWorkload.hpp b/src/backends/reference/workloads/RefActivationWorkload.hpp index e3bd8706a4..9814ac172b 100644 --- a/src/backends/reference/workloads/RefActivationWorkload.hpp +++ b/src/backends/reference/workloads/RefActivationWorkload.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -16,7 +16,7 @@ class RefActivationWorkload : public BaseWorkload<ActivationQueueDescriptor> public: using BaseWorkload<ActivationQueueDescriptor>::BaseWorkload; void Execute() const override; - void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; |