aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/MakeWorkloadHelper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/MakeWorkloadHelper.hpp')
-rw-r--r--src/armnn/backends/MakeWorkloadHelper.hpp24
1 files changed, 14 insertions, 10 deletions
diff --git a/src/armnn/backends/MakeWorkloadHelper.hpp b/src/armnn/backends/MakeWorkloadHelper.hpp
index a8729eb07c..a1f9b0b0eb 100644
--- a/src/armnn/backends/MakeWorkloadHelper.hpp
+++ b/src/armnn/backends/MakeWorkloadHelper.hpp
@@ -13,10 +13,12 @@ namespace
template<typename WorkloadType>
struct MakeWorkloadForType
{
- template<typename QueueDescriptorType>
- static std::unique_ptr<WorkloadType> Func(const QueueDescriptorType& descriptor, const WorkloadInfo& info)
+ template<typename QueueDescriptorType, typename... Args>
+ static std::unique_ptr<WorkloadType> Func(const QueueDescriptorType& descriptor,
+ const WorkloadInfo& info,
+ Args&&... args)
{
- return std::make_unique<WorkloadType>(descriptor, info);
+ return std::make_unique<WorkloadType>(descriptor, info, std::forward<Args>(args)...);
}
};
@@ -24,8 +26,10 @@ struct MakeWorkloadForType
template<>
struct MakeWorkloadForType<NullWorkload>
{
- template<typename QueueDescriptorType>
- static std::unique_ptr<NullWorkload> Func(const QueueDescriptorType& descriptor, const WorkloadInfo& info)
+ template<typename QueueDescriptorType, typename... Args>
+ static std::unique_ptr<NullWorkload> Func(const QueueDescriptorType& descriptor,
+ const WorkloadInfo& info,
+ Args&&... args)
{
return nullptr;
}
@@ -33,8 +37,8 @@ struct MakeWorkloadForType<NullWorkload>
// Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
-template <typename Float32Workload, typename Uint8Workload, typename QueueDescriptorType>
-std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info)
+template <typename Float32Workload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
+std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, Args&&... args)
{
const DataType dataType = !info.m_InputTensorInfos.empty() ?
info.m_InputTensorInfos[0].GetDataType()
@@ -46,9 +50,9 @@ std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, c
switch (dataType)
{
case DataType::Float32:
- return MakeWorkloadForType<Float32Workload>::Func(descriptor, info);
+ return MakeWorkloadForType<Float32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::QuantisedAsymm8:
- return MakeWorkloadForType<Uint8Workload>::Func(descriptor, info);
+ return MakeWorkloadForType<Uint8Workload>::Func(descriptor, info, std::forward<Args>(args)...);
default:
BOOST_ASSERT_MSG(false, "Unknown DataType.");
return nullptr;
@@ -56,4 +60,4 @@ std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, c
}
} //namespace
-} //namespace armnn \ No newline at end of file
+} //namespace armnn