From c9cc80455ff29fd2c8622c9487ec9c57ade6ea30 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Thu, 1 Nov 2018 16:15:57 +0000 Subject: IVGCVSW-1946: Remove armnn/src from the include paths Change-Id: I663a0a0fccb43ee960ec070121a59df9db0bb04e --- src/backends/backendsCommon/MakeWorkloadHelper.hpp | 82 ++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 src/backends/backendsCommon/MakeWorkloadHelper.hpp (limited to 'src/backends/backendsCommon/MakeWorkloadHelper.hpp') diff --git a/src/backends/backendsCommon/MakeWorkloadHelper.hpp b/src/backends/backendsCommon/MakeWorkloadHelper.hpp new file mode 100644 index 0000000000..78a9669530 --- /dev/null +++ b/src/backends/backendsCommon/MakeWorkloadHelper.hpp @@ -0,0 +1,82 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +namespace armnn +{ +namespace +{ + +// Make a workload of the specified WorkloadType. +template +struct MakeWorkloadForType +{ + template + static std::unique_ptr Func(const QueueDescriptorType& descriptor, + const WorkloadInfo& info, + Args&&... args) + { + return std::make_unique(descriptor, info, std::forward(args)...); + } +}; + +// Specialization for void workload type used for unsupported workloads. +template<> +struct MakeWorkloadForType +{ + template + static std::unique_ptr Func(const QueueDescriptorType& descriptor, + const WorkloadInfo& info, + Args&&... args) + { + return nullptr; + } +}; + +// Makes a workload for one the specified types based on the data type requirements of the tensorinfo. +// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos. +template +std::unique_ptr MakeWorkloadHelper(const QueueDescriptorType& descriptor, + const WorkloadInfo& info, + Args&&... args) +{ + const DataType dataType = !info.m_InputTensorInfos.empty() ? + info.m_InputTensorInfos[0].GetDataType() + : info.m_OutputTensorInfos[0].GetDataType(); + + BOOST_ASSERT(info.m_InputTensorInfos.empty() || info.m_OutputTensorInfos.empty() + || info.m_InputTensorInfos[0].GetDataType() == info.m_OutputTensorInfos[0].GetDataType()); + + switch (dataType) + { + case DataType::Float16: + return MakeWorkloadForType::Func(descriptor, info, std::forward(args)...); + case DataType::Float32: + return MakeWorkloadForType::Func(descriptor, info, std::forward(args)...); + case DataType::QuantisedAsymm8: + return MakeWorkloadForType::Func(descriptor, info, std::forward(args)...); + default: + BOOST_ASSERT_MSG(false, "Unknown DataType."); + return nullptr; + } +} + +// Makes a workload for one the specified types based on the data type requirements of the tensorinfo. +// Calling this method is the equivalent of calling the three typed MakeWorkload method with . +// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos. +template +std::unique_ptr MakeWorkloadHelper(const QueueDescriptorType& descriptor, + const WorkloadInfo& info, + Args&&... args) +{ + return MakeWorkloadHelper(descriptor, info, + std::forward(args)...); +} + + +} //namespace +} //namespace armnn -- cgit v1.2.1