// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include #include #include #include #include namespace armnn { class ITensorHandle; } // namespace armnn namespace { template void AddInputToWorkload(QueueDescriptor& descriptor, armnn::WorkloadInfo& info, const armnn::TensorInfo& tensorInfo, armnn::ITensorHandle* tensorHandle) { descriptor.m_Inputs.push_back(tensorHandle); info.m_InputTensorInfos.push_back(tensorInfo); } template void AddOutputToWorkload(QueueDescriptor& descriptor, armnn::WorkloadInfo& info, const armnn::TensorInfo& tensorInfo, armnn::ITensorHandle* tensorHandle) { descriptor.m_Outputs.push_back(tensorHandle); info.m_OutputTensorInfos.push_back(tensorInfo); } template void SetWorkloadInput(QueueDescriptor& descriptor, armnn::WorkloadInfo& info, unsigned int index, const armnn::TensorInfo& tensorInfo, armnn::ITensorHandle* tensorHandle) { descriptor.m_Inputs[index] = tensorHandle; info.m_InputTensorInfos[index] = tensorInfo; } template void SetWorkloadOutput(QueueDescriptor& descriptor, armnn::WorkloadInfo& info, unsigned int index, const armnn::TensorInfo& tensorInfo, armnn::ITensorHandle* tensorHandle) { descriptor.m_Outputs[index] = tensorHandle; info.m_OutputTensorInfos[index] = tensorInfo; } inline void ExecuteWorkload(armnn::IWorkload& workload, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, bool memoryManagementRequested = true) { const bool manageMemory = memoryManager && memoryManagementRequested; // Acquire working memory (if needed) if (manageMemory) { memoryManager->Acquire(); } // Perform PostAllocationConfiguration workload.PostAllocationConfigure(); // Execute the workload workload.Execute(); // Release working memory (if needed) if (manageMemory) { memoryManager->Release(); } } inline armnn::Optional GetBiasTypeFromWeightsType(armnn::Optional weightsType) { if (!weightsType) { return weightsType; } switch(weightsType.value()) { case armnn::DataType::Float16: case armnn::DataType::Float32: return weightsType; case armnn::DataType::QuantisedAsymm8: return armnn::DataType::Signed32; case armnn::DataType::QuantisedSymm16: return armnn::DataType::Signed32; default: BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); } return armnn::EmptyOptional(); } } // anonymous namespace