// ArmNN 21.02 — Workload.hpp (recovered from a Doxygen source-listing export)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"

#include <Profiling.hpp>
#include <ProfilingService.hpp>

#include <algorithm>

namespace armnn
{

19 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
20 // in the various workload factories.
21 // There should never be an instantiation of a NullWorkload.
22 class NullWorkload : public IWorkload
23 {
24  NullWorkload()=delete;
25 };
26 
27 template <typename QueueDescriptor>
28 class BaseWorkload : public IWorkload
29 {
30 public:
31 
32  BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
33  : m_Data(descriptor),
34  m_Guid(profiling::ProfilingService::GetNextGuid())
35  {
36  m_Data.Validate(info);
37  }
38 
39  void PostAllocationConfigure() override {}
40 
41  const QueueDescriptor& GetData() const { return m_Data; }
42 
43  profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
44 
45 protected:
48 };
49 
50 // TypedWorkload used
51 template <typename QueueDescriptor, armnn::DataType... DataTypes>
52 class TypedWorkload : public BaseWorkload<QueueDescriptor>
53 {
54 public:
55 
56  TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
57  : BaseWorkload<QueueDescriptor>(descriptor, info)
58  {
59  std::vector<armnn::DataType> dataTypes = {DataTypes...};
60  armnn::DataType expectedInputType;
61 
62  if (!info.m_InputTensorInfos.empty())
63  {
64  expectedInputType = info.m_InputTensorInfos.front().GetDataType();
65 
66  if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
67  {
68  ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
69  }
70  ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
71  info.m_InputTensorInfos.end(),
72  [&](auto it){
73  return it.GetDataType() == expectedInputType;
74  }),
75  "Trying to create workload with incorrect type");
76  }
77  armnn::DataType expectedOutputType;
78 
79  if (!info.m_OutputTensorInfos.empty())
80  {
81  expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
82 
83  if (!info.m_InputTensorInfos.empty())
84  {
85  if (expectedOutputType != expectedInputType)
86  {
87  ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
88  }
89  }
90  else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
91  {
92  ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
93  }
94  ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
95  info.m_OutputTensorInfos.end(),
96  [&](auto it){
97  return it.GetDataType() == expectedOutputType;
98  }),
99  "Trying to create workload with incorrect type");
100  }
101  }
102 };
103 
104 template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
105 class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
106 {
107 public:
108 
110  : BaseWorkload<QueueDescriptor>(descriptor, info)
111  {
112  ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
113  info.m_InputTensorInfos.end(),
114  [&](auto it){
115  return it.GetDataType() == InputDataType;
116  }),
117  "Trying to create workload with incorrect type");
118 
119  ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
120  info.m_OutputTensorInfos.end(),
121  [&](auto it){
122  return it.GetDataType() == OutputDataType;
123  }),
124  "Trying to create workload with incorrect type");
125  }
126 };
127 
128 // FirstInputTypedWorkload used to check type of the first input
129 template <typename QueueDescriptor, armnn::DataType DataType>
130 class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
131 {
132 public:
133 
135  : BaseWorkload<QueueDescriptor>(descriptor, info)
136  {
137  if (!info.m_InputTensorInfos.empty())
138  {
139  ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
140  "Trying to create workload with incorrect type");
141  }
142 
143  ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
144  info.m_OutputTensorInfos.end(),
145  [&](auto it){
146  return it.GetDataType() == DataType;
147  }),
148  "Trying to create workload with incorrect type");
149  }
150 };
151 
152 template <typename QueueDescriptor>
156 
157 template <typename QueueDescriptor>
159 
160 template <typename QueueDescriptor>
162 
163 template <typename QueueDescriptor>
165 
166 template <typename QueueDescriptor>
168 
169 template <typename QueueDescriptor>
170 using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
173 
174 template <typename QueueDescriptor>
175 using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
178 
179 template <typename QueueDescriptor>
180 using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
182  armnn::DataType::Float32>;
183 
184 template <typename QueueDescriptor>
185 using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
187  armnn::DataType::BFloat16>;
188 
189 template <typename QueueDescriptor>
190 using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
192  armnn::DataType::Float32>;
193 
194 template <typename QueueDescriptor>
195 using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
197  armnn::DataType::Float16>;
198 
199 template <typename QueueDescriptor>
200 using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
202  armnn::DataType::Float32>;
203 
} //namespace armnn
const QueueDescriptor m_Data
Definition: Workload.hpp:46
MultiTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:109
Copyright (c) 2021 ARM Limited and Contributors.
const profiling::ProfilingGuid m_Guid
Definition: Workload.hpp:47
BaseWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:32
std::vector< TensorInfo > m_InputTensorInfos
const QueueDescriptor & GetData() const
Definition: Workload.hpp:41
DataType
Definition: Types.hpp:32
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
void PostAllocationConfigure() override
Definition: Workload.hpp:39
std::vector< TensorInfo > m_OutputTensorInfos
Workload interface to enqueue a layer computation.
Definition: IWorkload.hpp:13
profiling::ProfilingGuid GetGuid() const final
Definition: Workload.hpp:43
FirstInputTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:134
TypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:56
Contains information about inputs and outputs to a layer.