aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/MakeWorkloadHelper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/MakeWorkloadHelper.hpp')
-rw-r--r--src/armnn/backends/MakeWorkloadHelper.hpp19
1 files changed, 17 insertions, 2 deletions
diff --git a/src/armnn/backends/MakeWorkloadHelper.hpp b/src/armnn/backends/MakeWorkloadHelper.hpp
index a1f9b0b0eb..64a7f8983b 100644
--- a/src/armnn/backends/MakeWorkloadHelper.hpp
+++ b/src/armnn/backends/MakeWorkloadHelper.hpp
@@ -9,7 +9,7 @@ namespace armnn
namespace
{
-// Make a workload of the specified WorkloadType
+// Make a workload of the specified WorkloadType.
template<typename WorkloadType>
struct MakeWorkloadForType
{
@@ -37,7 +37,8 @@ struct MakeWorkloadForType<NullWorkload>
// Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
-template <typename Float32Workload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
+template <typename Float16Workload, typename Float32Workload, typename Uint8Workload, typename QueueDescriptorType,
+ typename... Args>
std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, Args&&... args)
{
const DataType dataType = !info.m_InputTensorInfos.empty() ?
@@ -49,6 +50,8 @@ std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, c
switch (dataType)
{
+ case DataType::Float16:
+ return MakeWorkloadForType<Float16Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::Float32:
return MakeWorkloadForType<Float32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::QuantisedAsymm8:
@@ -59,5 +62,17 @@ std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, c
}
}
+// Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
+// Calling this method is the equivalent of calling the three typed MakeWorkload method with <FloatWorkload,
+// FloatWorkload, Uint8Workload>.
+// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
+template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
+std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, Args&&... args)
+{
+ return MakeWorkload<FloatWorkload, FloatWorkload, Uint8Workload>(descriptor, info,
+ std::forward<Args>(args)...);
+}
+
+
} //namespace
} //namespace armnn