ArmNN 22.02
Workload.hpp
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "IWorkload.hpp"
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include "WorkingMemDescriptor.hpp"

#include <Profiling.hpp>
#include <ProfilingService.hpp>

#include <algorithm>

namespace armnn
{

// NullWorkload is used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload.
class NullWorkload : public IWorkload
{
    NullWorkload() = delete;
};

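// Illustrative sketch, not a declaration from this header: the workload
// factories pick one workload class per data type and fall back to
// NullWorkload for unsupported combinations, so the deleted constructor
// turns an unsupported instantiation into a compile-time error:
//
//   NullWorkload unsupported;   // does not compile: NullWorkload() = delete
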
template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:

    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor),
          m_Guid(profiling::ProfilingService::GetNextGuid())
    {
        m_Data.Validate(info);
    }

    void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
    {
        ARMNN_LOG(info) << "Using default async workload execution, this may affect network performance";
        std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);

        m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
        m_Data.m_Outputs = workingMemDescriptor.m_Outputs;

        Execute();
    }

    void PostAllocationConfigure() override {}

    const QueueDescriptor& GetData() const { return m_Data; }

    profiling::ProfilingGuid GetGuid() const final { return m_Guid; }

    virtual bool SupportsTensorHandleReplacement() const override
    {
        return false;
    }

    // Replace input tensor handle with the given TensorHandle
    void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        armnn::IgnoreUnused(tensorHandle, slot);
        throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
    }

    // Replace output tensor handle with the given TensorHandle
    void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        armnn::IgnoreUnused(tensorHandle, slot);
        throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
    }

protected:
    QueueDescriptor m_Data;
    const profiling::ProfilingGuid m_Guid;

private:
    std::mutex m_AsyncWorkloadMutex;
};

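// Illustrative sketch, assuming the name SampleAdditionWorkload
// (AdditionQueueDescriptor comes from WorkloadData.hpp): a concrete workload
// derives from BaseWorkload, reads its bound tensors through m_Data, and
// implements Execute(). The default ExecuteAsync() above simply rebinds
// m_Data's inputs and outputs under a mutex and then calls Execute().
//
//   class SampleAdditionWorkload : public BaseWorkload<AdditionQueueDescriptor>
//   {
//   public:
//       using BaseWorkload<AdditionQueueDescriptor>::BaseWorkload;
//
//       void Execute() const override
//       {
//           // Element-wise add from m_Data.m_Inputs[0] and m_Data.m_Inputs[1]
//           // into m_Data.m_Outputs[0].
//       }
//   };
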
// TypedWorkload used to check that the input and output tensor data types match one of the
// DataTypes given as template arguments.
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        std::vector<armnn::DataType> dataTypes = {DataTypes...};
        armnn::DataType expectedInputType;

        if (!info.m_InputTensorInfos.empty())
        {
            expectedInputType = info.m_InputTensorInfos.front().GetDataType();

            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
                                         info.m_InputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedInputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
        armnn::DataType expectedOutputType;

        if (!info.m_OutputTensorInfos.empty())
        {
            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();

            if (!info.m_InputTensorInfos.empty())
            {
                if (expectedOutputType != expectedInputType)
                {
                    ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
                }
            }
            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
                                         info.m_OutputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedOutputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
    }
};

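// Illustrative use, assuming the names SampleReluWorkload and
// SampleQueueDescriptor: listing one or more DataTypes as template arguments
// restricts the workload to tensors of those types; the asserts in the
// constructor above fire for anything else.
//
//   class SampleReluWorkload
//       : public TypedWorkload<SampleQueueDescriptor,
//                              armnn::DataType::Float16,
//                              armnn::DataType::Float32>
//   {
//   public:
//       using TypedWorkload<SampleQueueDescriptor,
//                           armnn::DataType::Float16,
//                           armnn::DataType::Float32>::TypedWorkload;
//
//       void Execute() const override {}
//   };
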
template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == InputDataType;
                                     }),
                         "Trying to create workload with incorrect type");

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == OutputDataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};

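// Illustrative use, assuming the name SampleQueueDescriptor: fixing distinct
// input and output types at compile time is how the conversion aliases at the
// end of this file are built.
//
//   using SampleFp16ToFp32Workload =
//       MultiTypedWorkload<SampleQueueDescriptor,
//                          armnn::DataType::Float16,
//                          armnn::DataType::Float32>;
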
// FirstInputTypedWorkload used to check the type of the first input
template <typename QueueDescriptor, armnn::DataType DataType>
class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        if (!info.m_InputTensorInfos.empty())
        {
            ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
                             "Trying to create workload with incorrect type");
        }

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == DataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};

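// Illustrative use, assuming the name SampleGatherWorkload
// (GatherQueueDescriptor comes from WorkloadData.hpp): only the first input's
// type is constrained, so a Gather-style layer can take Float32 data plus
// Signed32 indices, while the outputs are still checked against Float32.
//
//   class SampleGatherWorkload
//       : public FirstInputTypedWorkload<GatherQueueDescriptor, armnn::DataType::Float32>
//   {
//   public:
//       using FirstInputTypedWorkload<GatherQueueDescriptor,
//                                     armnn::DataType::Float32>::FirstInputTypedWorkload;
//
//       void Execute() const override {}
//   };
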
template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
                                    armnn::DataType::Float16,
                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;

template <typename QueueDescriptor>
using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;

template <typename QueueDescriptor>
using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                         armnn::DataType::Float32,
                                                         armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                       armnn::DataType::QAsymmU8,
                                                       armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::BFloat16,
                                                     armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::Float32,
                                                     armnn::DataType::BFloat16>;

template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float16,
                                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float32,
                                                    armnn::DataType::Float16>;

template <typename QueueDescriptor>
using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                  armnn::DataType::QAsymmU8,
                                                  armnn::DataType::Float32>;

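// Illustrative use, assuming the name SampleConvertFp16ToFp32Workload
// (ConvertFp16ToFp32QueueDescriptor comes from WorkloadData.hpp): backend
// conversion workloads typically derive from one of the aliases above.
//
//   class SampleConvertFp16ToFp32Workload
//       : public Float16ToFloat32Workload<ConvertFp16ToFp32QueueDescriptor>
//   {
//   public:
//       using Float16ToFloat32Workload<ConvertFp16ToFp32QueueDescriptor>::Float16ToFloat32Workload;
//
//       void Execute() const override {}
//   };
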
} //namespace armnn