diff options
author | Colm Donelan <colm.donelan@arm.com> | 2021-12-10 12:43:54 +0000 |
---|---|---|
committer | Colm Donelan <colm.donelan@arm.com> | 2021-12-15 12:53:20 +0000 |
commit | 0c47974f1800e8770904aecaef15d6f105758c4e (patch) | |
tree | f5424858c6fe6f33376b3432580179958ab8ac5a /include/armnnTestUtils/WorkloadTestUtils.hpp | |
parent | cdbb09f6e15ea6698a62197cf76ecba87b81cb9d (diff) | |
download | armnn-0c47974f1800e8770904aecaef15d6f105758c4e.tar.gz |
IVGCVSW-6626 Promote backend headers in backendCommon to armnn/backends
Move the following header files from backendsCommon to armnn/backends.
* MemCopyWorkload.hpp
* TensorHandle.hpp
* Workload.hpp
* WorkloadData.hpp
* WorkloadFactory.hpp
Replace them with forwarding headers and a pragma deprecation message.
Resolve the deprecation messages in Arm NN code.
Signed-off-by: Colm Donelan <colm.donelan@arm.com>
Change-Id: I47f116b30f86e478c9057795bc518c391a8ae514
Diffstat (limited to 'include/armnnTestUtils/WorkloadTestUtils.hpp')
-rw-r--r-- | include/armnnTestUtils/WorkloadTestUtils.hpp | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/include/armnnTestUtils/WorkloadTestUtils.hpp b/include/armnnTestUtils/WorkloadTestUtils.hpp new file mode 100644 index 0000000000..156258a549 --- /dev/null +++ b/include/armnnTestUtils/WorkloadTestUtils.hpp @@ -0,0 +1,113 @@ +// +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include <armnn/Tensor.hpp> + +#include <armnn/backends/IBackendInternal.hpp> +#include <armnn/backends/IMemoryManager.hpp> +#include <armnn/backends/Workload.hpp> +#include <armnn/backends/WorkloadInfo.hpp> + +namespace armnn +{ +class ITensorHandle; +} // namespace armnn + +namespace +{ + +template <typename QueueDescriptor> +void AddInputToWorkload(QueueDescriptor& descriptor, + armnn::WorkloadInfo& info, + const armnn::TensorInfo& tensorInfo, + armnn::ITensorHandle* tensorHandle) +{ + descriptor.m_Inputs.push_back(tensorHandle); + info.m_InputTensorInfos.push_back(tensorInfo); +} + +template <typename QueueDescriptor> +void AddOutputToWorkload(QueueDescriptor& descriptor, + armnn::WorkloadInfo& info, + const armnn::TensorInfo& tensorInfo, + armnn::ITensorHandle* tensorHandle) +{ + descriptor.m_Outputs.push_back(tensorHandle); + info.m_OutputTensorInfos.push_back(tensorInfo); +} + +template <typename QueueDescriptor> +void SetWorkloadInput(QueueDescriptor& descriptor, + armnn::WorkloadInfo& info, + unsigned int index, + const armnn::TensorInfo& tensorInfo, + armnn::ITensorHandle* tensorHandle) +{ + descriptor.m_Inputs[index] = tensorHandle; + info.m_InputTensorInfos[index] = tensorInfo; +} + +template <typename QueueDescriptor> +void SetWorkloadOutput(QueueDescriptor& descriptor, + armnn::WorkloadInfo& info, + unsigned int index, + const armnn::TensorInfo& tensorInfo, + armnn::ITensorHandle* tensorHandle) +{ + descriptor.m_Outputs[index] = tensorHandle; + info.m_OutputTensorInfos[index] = tensorInfo; +} + +inline void ExecuteWorkload(armnn::IWorkload& workload, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + bool memoryManagementRequested = true) +{ + const bool manageMemory = memoryManager && memoryManagementRequested; + + // Acquire working memory (if needed) + if (manageMemory) + { + memoryManager->Acquire(); + } + + // Perform PostAllocationConfiguration + workload.PostAllocationConfigure(); + + // Execute the workload + workload.Execute(); + + // Release working memory (if needed) + if (manageMemory) + { + memoryManager->Release(); + } +} + +inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType) +{ + if (!weightsType) + { + return weightsType; + } + + switch(weightsType.value()) + { + case armnn::DataType::BFloat16: + case armnn::DataType::Float16: + case armnn::DataType::Float32: + return weightsType; + case armnn::DataType::QAsymmS8: + case armnn::DataType::QAsymmU8: + case armnn::DataType::QSymmS8: + case armnn::DataType::QSymmS16: + return armnn::DataType::Signed32; + default: + ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); + } + return armnn::EmptyOptional(); +} + +} // anonymous namespace |