ArmNN 22.05: Workload.hpp
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "IWorkload.hpp"
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include "WorkingMemDescriptor.hpp"

#include <armnn/Logging.hpp>

#include <Profiling.hpp>

#include <client/include/IProfilingService.hpp>

#include <algorithm>
#if !defined(ARMNN_DISABLE_THREADS)
#include <mutex> // std::mutex / std::lock_guard used by BaseWorkload::ExecuteAsync
#endif

namespace armnn
{

// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload.
class NullWorkload : public IWorkload
{
    NullWorkload() = delete;
};
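
// Illustrative note (not part of the upstream header): because the default
// constructor is private and deleted, any attempt to instantiate NullWorkload
// fails to compile:
//
//   // NullWorkload nullWorkload;   // error: use of deleted constructor
//
// Factories that dispatch through MakeWorkload<> rely on this: mapping an
// unsupported data type to NullWorkload guarantees no workload object can
// ever be created for it.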

template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:

    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor),
          m_Guid(arm::pipe::IProfilingService::GetNextGuid())
    {
        m_Data.Validate(info);
    }

    void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
    {
        ARMNN_LOG(info) << "Using default async workload execution, this will affect network performance";
#if !defined(ARMNN_DISABLE_THREADS)
        std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
#endif
        m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
        m_Data.m_Outputs = workingMemDescriptor.m_Outputs;

        Execute();
    }

    void PostAllocationConfigure() override {}

    const QueueDescriptor& GetData() const { return m_Data; }

    arm::pipe::ProfilingGuid GetGuid() const final { return m_Guid; }

    virtual bool SupportsTensorHandleReplacement() const override
    {
        return false;
    }

    // Replace input tensor handle with the given TensorHandle
    void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        armnn::IgnoreUnused(tensorHandle, slot);
        throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
    }

    // Replace output tensor handle with the given TensorHandle
    void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        armnn::IgnoreUnused(tensorHandle, slot);
        throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
    }

protected:
    QueueDescriptor m_Data;
    const arm::pipe::ProfilingGuid m_Guid;

private:
#if !defined(ARMNN_DISABLE_THREADS)
    std::mutex m_AsyncWorkloadMutex;
#endif
};
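
// Illustrative sketch (hypothetical names, not part of the upstream header):
// a concrete workload derives from BaseWorkload<QueueDescriptor>, inherits the
// validated descriptor in m_Data, and implements Execute():
//
//   class SampleAdditionWorkload : public BaseWorkload<AdditionQueueDescriptor>
//   {
//   public:
//       using BaseWorkload<AdditionQueueDescriptor>::BaseWorkload;
//
//       void Execute() const override
//       {
//           // Read tensors via m_Data.m_Inputs, write via m_Data.m_Outputs.
//       }
//   };
//
// The default ExecuteAsync() above simply swaps the working-memory I/O into
// m_Data under a mutex and calls Execute(); backends can override it with a
// genuinely concurrent implementation.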

// TypedWorkload used to check that all input and output tensors have one of the
// given DataTypes and agree with each other.
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        std::vector<armnn::DataType> dataTypes = {DataTypes...};
        armnn::DataType expectedInputType;

        if (!info.m_InputTensorInfos.empty())
        {
            expectedInputType = info.m_InputTensorInfos.front().GetDataType();

            // The first input must have one of the accepted data types.
            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            // All remaining inputs must match the first input's type.
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
                                         info.m_InputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedInputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
        armnn::DataType expectedOutputType;

        if (!info.m_OutputTensorInfos.empty())
        {
            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();

            // Outputs must match the input type when inputs exist; otherwise
            // the first output must itself have one of the accepted types.
            if (!info.m_InputTensorInfos.empty())
            {
                if (expectedOutputType != expectedInputType)
                {
                    ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
                }
            }
            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            // All remaining outputs must match the first output's type.
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
                                         info.m_OutputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedOutputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
    }
};
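
// Illustrative usage (hypothetical class name, not part of the upstream
// header): a workload restricted to float tensors lists both accepted types;
// construction then asserts if any tensor carries a different data type:
//
//   class SampleFloatWorkload
//       : public TypedWorkload<ActivationQueueDescriptor,
//                              armnn::DataType::Float16,
//                              armnn::DataType::Float32>
//   {
//   public:
//       SampleFloatWorkload(const ActivationQueueDescriptor& descriptor,
//                           const WorkloadInfo& info)
//           : TypedWorkload<ActivationQueueDescriptor,
//                           armnn::DataType::Float16,
//                           armnn::DataType::Float32>(descriptor, info)
//       {}
//
//       void Execute() const override { /* ... */ }
//   };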

template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        // All inputs must have InputDataType and all outputs OutputDataType.
        ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == InputDataType;
                                     }),
                         "Trying to create workload with incorrect type");

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == OutputDataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};
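
// Illustrative usage (hypothetical class name, not part of the upstream
// header): the conversion aliases defined further below pin different input
// and output types, e.g. a Float16 -> Float32 conversion workload:
//
//   class SampleConvertFp16ToFp32Workload
//       : public Float16ToFloat32Workload<ConvertFp16ToFp32QueueDescriptor>
//   {
//   public:
//       SampleConvertFp16ToFp32Workload(const ConvertFp16ToFp32QueueDescriptor& descriptor,
//                                       const WorkloadInfo& info)
//           : Float16ToFloat32Workload<ConvertFp16ToFp32QueueDescriptor>(descriptor, info)
//       {}
//
//       void Execute() const override { /* convert each element */ }
//   };
//
// Construction asserts unless every input tensor is Float16 and every output
// tensor is Float32.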

// FirstInputTypedWorkload used to check the type of the first input only.
template <typename QueueDescriptor, armnn::DataType DataType>
class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        if (!info.m_InputTensorInfos.empty())
        {
            ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
                             "Trying to create workload with incorrect type");
        }

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == DataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};
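
// Illustrative usage (hypothetical class name, not part of the upstream
// header): only the first input's type is pinned, which suits workloads whose
// secondary inputs legitimately differ in type, for example integer indices
// alongside Float32 data:
//
//   class SampleGatherWorkload
//       : public FirstInputTypedWorkload<GatherQueueDescriptor,
//                                        armnn::DataType::Float32>
//   {
//   public:
//       SampleGatherWorkload(const GatherQueueDescriptor& descriptor,
//                            const WorkloadInfo& info)
//           : FirstInputTypedWorkload<GatherQueueDescriptor,
//                                     armnn::DataType::Float32>(descriptor, info)
//       {}
//
//       void Execute() const override { /* ... */ }
//   };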

template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
                                    armnn::DataType::Float16,
                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;

template <typename QueueDescriptor>
using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;

template <typename QueueDescriptor>
using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                         armnn::DataType::Float32,
                                                         armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                       armnn::DataType::QAsymmU8,
                                                       armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::BFloat16,
                                                     armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::Float32,
                                                     armnn::DataType::BFloat16>;

template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float16,
                                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float32,
                                                    armnn::DataType::Float16>;

template <typename QueueDescriptor>
using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                  armnn::DataType::QAsymmU8,
                                                  armnn::DataType::Float32>;

} //namespace armnn