ArmNN
 23.08
Workload.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "IWorkload.hpp"
8 #include "WorkloadData.hpp"
9 #include "WorkloadInfo.hpp"
10 #include "WorkingMemDescriptor.hpp"
11 #include "ExecutionData.hpp"
12 
13 #include <armnn/Logging.hpp>
14 
15 #include <Profiling.hpp>
16 
17 #include <client/include/IProfilingService.hpp>
18 
19 #include <algorithm>
20 
21 namespace armnn
22 {
23 
24 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
25 // in the various workload factories.
26 // There should never be an instantiation of a NullWorkload.
27 class NullWorkload : public IWorkload
28 {
29  NullWorkload()=delete;
30 };
31 
32 template <typename QueueDescriptor>
33 class BaseWorkload : public IWorkload
34 {
35 public:
36 
37  BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
38  : m_Data(descriptor),
39  m_Guid(arm::pipe::IProfilingService::GetNextGuid()),
41  {
42  m_Data.Validate(info);
43  }
44 
45  virtual const std::string& GetName() const override
46  {
47  return m_Name;
48  }
49 
50  void ExecuteAsync(ExecutionData& executionData) override
51  {
52  ARMNN_LOG(info) << "Using default async workload execution, this will network affect performance";
53 #if !defined(ARMNN_DISABLE_THREADS)
54  std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
55 #endif
56  WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
57  m_Data.m_Inputs = workingMemDescriptor->m_Inputs;
58  m_Data.m_Outputs = workingMemDescriptor->m_Outputs;
59 
60  Execute();
61  };
62 
63  void PostAllocationConfigure() override {}
64 
65  const QueueDescriptor& GetData() const { return m_Data; }
66 
67  arm::pipe::ProfilingGuid GetGuid() const final { return m_Guid; }
68 
69  virtual bool SupportsTensorHandleReplacement() const override
70  {
71  return false;
72  }
73 
74  // Replace input tensor handle with the given TensorHandle
75  void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
76  {
77  armnn::IgnoreUnused(tensorHandle, slot);
78  throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
79  }
80 
81  // Replace output tensor handle with the given TensorHandle
82  void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
83  {
84  armnn::IgnoreUnused(tensorHandle, slot);
85  throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
86  }
87 
88 protected:
90  const arm::pipe::ProfilingGuid m_Guid;
91  const std::string m_Name;
92 
93 private:
94 #if !defined(ARMNN_DISABLE_THREADS)
95  std::mutex m_AsyncWorkloadMutex;
96 #endif
97 };
98 
99 // TypedWorkload used
100 template <typename QueueDescriptor, armnn::DataType... DataTypes>
101 class TypedWorkload : public BaseWorkload<QueueDescriptor>
102 {
103 public:
104 
105  TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
106  : BaseWorkload<QueueDescriptor>(descriptor, info)
107  {
108  std::vector<armnn::DataType> dataTypes = {DataTypes...};
109  armnn::DataType expectedInputType;
110 
111  if (!info.m_InputTensorInfos.empty())
112  {
113  expectedInputType = info.m_InputTensorInfos.front().GetDataType();
114 
115  if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
116  {
117  throw armnn::Exception("Trying to create workload with incorrect type");
118  }
119  if (std::all_of(std::next(info.m_InputTensorInfos.begin()),
120  info.m_InputTensorInfos.end(),
121  [&](auto it){
122  return it.GetDataType() == expectedInputType;
123  }) == false)
124  {
125  throw armnn::Exception("Trying to create workload with incorrect type");
126  }
127  }
128  armnn::DataType expectedOutputType;
129 
130  if (!info.m_OutputTensorInfos.empty())
131  {
132  expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
133 
134  if (!info.m_InputTensorInfos.empty())
135  {
136  expectedInputType = info.m_InputTensorInfos.front().GetDataType();
137 
138  if (expectedOutputType != expectedInputType)
139  {
140  throw armnn::Exception( "Trying to create workload with incorrect type");
141  }
142  }
143  else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
144  {
145  throw armnn::Exception("Trying to create workload with incorrect type");
146  }
147  if (std::all_of(std::next(info.m_OutputTensorInfos.begin()),
148  info.m_OutputTensorInfos.end(),
149  [&](auto it){
150  return it.GetDataType() == expectedOutputType;
151  }) == false)
152  {
153  throw armnn::Exception("Trying to create workload with incorrect type");
154  }
155  }
156  }
157 };
158 
159 template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
160 class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
161 {
162 public:
163 
164  MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
165  : BaseWorkload<QueueDescriptor>(descriptor, info)
166  {
167  if (std::all_of(info.m_InputTensorInfos.begin(),
168  info.m_InputTensorInfos.end(),
169  [&](auto it){
170  return it.GetDataType() == InputDataType;
171  }) == false)
172  {
173  throw armnn::Exception("Trying to create workload with incorrect type");
174  }
175  if (std::all_of(info.m_OutputTensorInfos.begin(),
176  info.m_OutputTensorInfos.end(),
177  [&](auto it){
178  return it.GetDataType() == OutputDataType;
179  }) == false)
180  {
181  throw armnn::Exception("Trying to create workload with incorrect type");
182  }
183  }
184 };
185 
186 // FirstInputTypedWorkload used to check type of the first input
187 template <typename QueueDescriptor, armnn::DataType DataType>
188 class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
189 {
190 public:
191 
192  FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
193  : BaseWorkload<QueueDescriptor>(descriptor, info)
194  {
195  if (!info.m_InputTensorInfos.empty())
196  {
197  if (info.m_InputTensorInfos.front().GetDataType() != DataType)
198  {
199  throw armnn::Exception("Trying to create workload with incorrect type");
200  }
201  }
202 
203  if (std::all_of(info.m_OutputTensorInfos.begin(),
204  info.m_OutputTensorInfos.end(),
205  [&](auto it){
206  return it.GetDataType() == DataType;
207  }) == false)
208  {
209  throw armnn::Exception("Trying to create workload with incorrect type");
210  }
211  }
212 };
213 
214 template <typename QueueDescriptor>
215 using FloatWorkload = TypedWorkload<QueueDescriptor,
218 
219 template <typename QueueDescriptor>
221 
222 template <typename QueueDescriptor>
224 
225 template <typename QueueDescriptor>
227 
228 template <typename QueueDescriptor>
230 
231 template <typename QueueDescriptor>
235 
236 template <typename QueueDescriptor>
240 
241 template <typename QueueDescriptor>
245 
246 template <typename QueueDescriptor>
250 
251 template <typename QueueDescriptor>
255 
256 template <typename QueueDescriptor>
260 
261 template <typename QueueDescriptor>
265 
266 } //namespace armnn
armnn::BaseWorkload::BaseWorkload
BaseWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:37
armnn::DataType::Boolean
@ Boolean
arm
Definition: BackendRegistry.hpp:15
armnn::BaseWorkload::SupportsTensorHandleReplacement
virtual bool SupportsTensorHandleReplacement() const override
Definition: Workload.hpp:69
armnn::BaseWorkload::GetData
const QueueDescriptor & GetData() const
Definition: Workload.hpp:65
WorkloadData.hpp
ExecutionData.hpp
armnn::experimental::ExecutionData::m_Data
void * m_Data
Definition: ExecutionData.hpp:16
armnn::BaseWorkload::ReplaceOutputTensorHandle
void ReplaceOutputTensorHandle(ITensorHandle *tensorHandle, unsigned int slot) override
Definition: Workload.hpp:82
armnn::IWorkload
Workload interface to enqueue a layer computation.
Definition: IWorkload.hpp:23
armnn::BaseWorkload::m_Name
const std::string m_Name
Definition: Workload.hpp:91
Profiling.hpp
armnn::DataType::Float32
@ Float32
armnn::TypedWorkload::TypedWorkload
TypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:105
armnn::ITensorHandle
Definition: ITensorHandle.hpp:16
armnn::DataType::QAsymmU8
@ QAsymmU8
armnn::TypedWorkload
Definition: Workload.hpp:101
armnn::BaseWorkload::GetName
virtual const std::string & GetName() const override
Definition: Workload.hpp:45
armnn::DataType::BFloat16
@ BFloat16
armnn::IWorkload::Execute
virtual void Execute() const =0
armnn::BaseWorkload::ExecuteAsync
void ExecuteAsync(ExecutionData &executionData) override
Definition: Workload.hpp:50
ARMNN_LOG
#define ARMNN_LOG(severity)
Definition: Logging.hpp:212
IWorkload.hpp
armnn::DataType::Float16
@ Float16
Logging.hpp
armnn::WorkloadInfo
Contains information about TensorInfos of a layer.
Definition: WorkloadInfo.hpp:16
armnn::DataType
DataType
Definition: Types.hpp:48
armnn::FloatWorkload
TypedWorkload< QueueDescriptor, armnn::DataType::Float16, armnn::DataType::Float32 > FloatWorkload
Definition: Workload.hpp:217
armnn::QueueDescriptor
Definition: WorkloadData.hpp:24
WorkingMemDescriptor.hpp
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
armnn::BoostLogSeverityMapping::info
@ info
armnn::QueueDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkloadData.hpp:27
armnn::BaseWorkload
Definition: Workload.hpp:33
armnn::FirstInputTypedWorkload
Definition: Workload.hpp:188
armnn::BaseWorkload::GetGuid
arm::pipe::ProfilingGuid GetGuid() const final
Definition: Workload.hpp:67
armnn::BaseWorkload::PostAllocationConfigure
void PostAllocationConfigure() override
Definition: Workload.hpp:63
armnn::BaseWorkload::m_Data
QueueDescriptor m_Data
Definition: Workload.hpp:89
armnn::IgnoreUnused
void IgnoreUnused(Ts &&...)
Definition: IgnoreUnused.hpp:14
armnn::BaseWorkload::m_Guid
const arm::pipe::ProfilingGuid m_Guid
Definition: Workload.hpp:90
armnn::experimental::WorkingMemDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkingMemDescriptor.hpp:20
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::experimental::WorkingMemDescriptor
Definition: WorkingMemDescriptor.hpp:18
armnn::NullWorkload
Definition: Workload.hpp:27
armnn::BaseWorkload::ReplaceInputTensorHandle
void ReplaceInputTensorHandle(ITensorHandle *tensorHandle, unsigned int slot) override
Definition: Workload.hpp:75
armnn::UnimplementedException
Definition: Exceptions.hpp:98
armnn::MultiTypedWorkload::MultiTypedWorkload
MultiTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:164
armnn::FirstInputTypedWorkload::FirstInputTypedWorkload
FirstInputTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:192
WorkloadInfo.hpp
armnn::experimental::WorkingMemDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkingMemDescriptor.hpp:21
armnn::QueueDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkloadData.hpp:26
armnn::MultiTypedWorkload
Definition: Workload.hpp:160
armnn::experimental::ExecutionData
Definition: ExecutionData.hpp:14