ArmNN
 22.08
Workload.hpp
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "IWorkload.hpp"
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include "WorkingMemDescriptor.hpp"
#include "ExecutionData.hpp"

#include <armnn/Logging.hpp>

#include <Profiling.hpp>

#include <client/include/IProfilingService.hpp>

#include <algorithm>
namespace armnn
{

// NullWorkload is used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload.
class NullWorkload : public IWorkload
{
    NullWorkload() = delete;
};
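
// Editor's sketch (hypothetical helper, not part of the original header): factories bind
// unsupported DataType/layer combinations to NullWorkload. Because its default constructor
// is deleted, such a combination can never be instantiated; a dispatch helper can map
// NullWorkload to nullptr so unsupported cases are rejected cleanly. Assumes <memory> and
// <type_traits>; MakeExampleWorkload is an illustrative name only.
template <typename WorkloadType, typename DescriptorType>
std::unique_ptr<IWorkload> MakeExampleWorkload(const DescriptorType& descriptor,
                                               const WorkloadInfo& info)
{
    if constexpr (std::is_same_v<WorkloadType, NullWorkload>)
    {
        return nullptr; // unsupported combination; NullWorkload is never constructed
    }
    else
    {
        return std::make_unique<WorkloadType>(descriptor, info);
    }
}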

template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:

    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor),
          m_Guid(arm::pipe::IProfilingService::GetNextGuid())
    {
        m_Data.Validate(info);
    }

    void ExecuteAsync(ExecutionData& executionData) override
    {
        ARMNN_LOG(info) << "Using default async workload execution, this will affect network performance";
#if !defined(ARMNN_DISABLE_THREADS)
        std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
#endif
        WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
        m_Data.m_Inputs = workingMemDescriptor->m_Inputs;
        m_Data.m_Outputs = workingMemDescriptor->m_Outputs;

        Execute();
    }

    void PostAllocationConfigure() override {}

    const QueueDescriptor& GetData() const { return m_Data; }

    arm::pipe::ProfilingGuid GetGuid() const final { return m_Guid; }

    bool SupportsTensorHandleReplacement() const override
    {
        return false;
    }

    // Replace input tensor handle with the given TensorHandle
    void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        armnn::IgnoreUnused(tensorHandle, slot);
        throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
    }

    // Replace output tensor handle with the given TensorHandle
    void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        armnn::IgnoreUnused(tensorHandle, slot);
        throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
    }

protected:
    QueueDescriptor m_Data;
    const arm::pipe::ProfilingGuid m_Guid;

private:
#if !defined(ARMNN_DISABLE_THREADS)
    std::mutex m_AsyncWorkloadMutex;
#endif
};
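
// Editor's sketch (hypothetical, not part of the original header): a minimal concrete
// workload built on BaseWorkload<>. It reuses AdditionQueueDescriptor from
// WorkloadData.hpp; the inherited constructor validates the descriptor against the
// WorkloadInfo and assigns a profiling GUID.
class ExampleAdditionWorkload : public BaseWorkload<AdditionQueueDescriptor>
{
public:
    using BaseWorkload<AdditionQueueDescriptor>::BaseWorkload;

    void Execute() const override
    {
        // A real backend would map m_Data.m_Inputs / m_Data.m_Outputs here and run its kernel.
    }
};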

// TypedWorkload checks at construction time that all input and output tensors
// use one of the DataTypes given as template parameters.
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        std::vector<armnn::DataType> dataTypes = {DataTypes...};
        armnn::DataType expectedInputType;

        if (!info.m_InputTensorInfos.empty())
        {
            expectedInputType = info.m_InputTensorInfos.front().GetDataType();

            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
                                         info.m_InputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedInputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
        armnn::DataType expectedOutputType;

        if (!info.m_OutputTensorInfos.empty())
        {
            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();

            if (!info.m_InputTensorInfos.empty())
            {
                if (expectedOutputType != expectedInputType)
                {
                    ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
                }
            }
            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
                                         info.m_OutputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedOutputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
    }
};
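
// Editor's sketch (hypothetical): constraining the addition example above to float tensors.
// Constructing this workload with, say, QAsymmU8 tensor infos would trip the
// ARMNN_ASSERT_MSG checks in the TypedWorkload constructor.
class ExampleFloatAdditionWorkload
    : public TypedWorkload<AdditionQueueDescriptor, DataType::Float16, DataType::Float32>
{
public:
    using TypedWorkload<AdditionQueueDescriptor, DataType::Float16, DataType::Float32>::TypedWorkload;

    void Execute() const override {} // kernel omitted in this sketch
};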

// MultiTypedWorkload checks that all inputs match InputDataType and all outputs
// match OutputDataType, allowing the two to differ.
template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == InputDataType;
                                     }),
                         "Trying to create workload with incorrect type");

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == OutputDataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};
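
// Editor's sketch (hypothetical): MultiTypedWorkload suits conversion layers, where input
// and output types legitimately differ. This mirrors the Float16ToFloat32Workload alias
// declared near the end of this header.
class ExampleFp16ToFp32Workload
    : public MultiTypedWorkload<ConvertFp16ToFp32QueueDescriptor,
                                DataType::Float16, DataType::Float32>
{
public:
    using MultiTypedWorkload<ConvertFp16ToFp32QueueDescriptor,
                             DataType::Float16, DataType::Float32>::MultiTypedWorkload;

    void Execute() const override {} // conversion kernel omitted in this sketch
};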

// FirstInputTypedWorkload checks the type of the first input (and of all outputs).
template <typename QueueDescriptor, armnn::DataType DataType>
class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        if (!info.m_InputTensorInfos.empty())
        {
            ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
                             "Trying to create workload with incorrect type");
        }

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == DataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};
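
// Editor's sketch (hypothetical): FirstInputTypedWorkload fits layers such as Gather,
// whose second input carries integer indices, so only the first input (and all outputs)
// are checked against the template DataType.
class ExampleGatherFloat32Workload
    : public FirstInputTypedWorkload<GatherQueueDescriptor, DataType::Float32>
{
public:
    using FirstInputTypedWorkload<GatherQueueDescriptor, DataType::Float32>::FirstInputTypedWorkload;

    void Execute() const override {} // gather kernel omitted in this sketch
};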

template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
                                    armnn::DataType::Float16,
                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;

template <typename QueueDescriptor>
using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;

template <typename QueueDescriptor>
using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                         armnn::DataType::Float32,
                                                         armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                       armnn::DataType::QAsymmU8,
                                                       armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::BFloat16,
                                                     armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::Float32,
                                                     armnn::DataType::BFloat16>;

template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float16,
                                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float32,
                                                    armnn::DataType::Float16>;

template <typename QueueDescriptor>
using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                  armnn::DataType::QAsymmU8,
                                                  armnn::DataType::Float32>;

} // namespace armnn
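
// Editor's sketch (hypothetical, not part of the original header): backends typically
// declare their concrete workloads in terms of the aliases above, e.g. a float32-only
// addition workload built on the Float32Workload alias.
class ExampleBackendAdditionFloat32Workload
    : public armnn::Float32Workload<armnn::AdditionQueueDescriptor>
{
public:
    ExampleBackendAdditionFloat32Workload(const armnn::AdditionQueueDescriptor& descriptor,
                                          const armnn::WorkloadInfo& info)
        : armnn::Float32Workload<armnn::AdditionQueueDescriptor>(descriptor, info)
    {}

    void Execute() const override {} // kernel omitted in this sketch
};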