ArmNN 21.11
Workload.hpp
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"

#include <armnn/Logging.hpp>

#include <Profiling.hpp>
#include <ProfilingService.hpp>

#include <algorithm>
namespace armnn
{

// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload.
class NullWorkload : public IWorkload
{
    NullWorkload() = delete;
};

template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:

    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor),
          m_Guid(profiling::ProfilingService::GetNextGuid())
    {
        m_Data.Validate(info);
    }

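    // Default (synchronous) fallback for asynchronous execution: a mutex
    // serializes callers, the working-memory tensors are swapped into m_Data,
    // and the regular Execute() path is reused. Backends can override this
    // with a genuinely asynchronous implementation.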
    void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
    {
        ARMNN_LOG(info) << "Using default async workload execution, this will affect network performance";
        std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);

        m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
        m_Data.m_Outputs = workingMemDescriptor.m_Outputs;

        Execute();
    }

    void PostAllocationConfigure() override {}

    const QueueDescriptor& GetData() const { return m_Data; }

    profiling::ProfilingGuid GetGuid() const final { return m_Guid; }

protected:
    QueueDescriptor m_Data;
    const profiling::ProfilingGuid m_Guid;

private:
    std::mutex m_AsyncWorkloadMutex;
};

// TypedWorkload used to check that the workload's input and output tensors
// all have one of the expected data types.
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        std::vector<armnn::DataType> dataTypes = {DataTypes...};
        armnn::DataType expectedInputType;

        if (!info.m_InputTensorInfos.empty())
        {
            expectedInputType = info.m_InputTensorInfos.front().GetDataType();

            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
                                         info.m_InputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedInputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
        armnn::DataType expectedOutputType;

        if (!info.m_OutputTensorInfos.empty())
        {
            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();

            if (!info.m_InputTensorInfos.empty())
            {
                if (expectedOutputType != expectedInputType)
                {
                    ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
                }
            }
            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
                                         info.m_OutputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedOutputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
    }
};

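// MultiTypedWorkload used to check that every input tensor matches
// InputDataType and every output tensor matches OutputDataType.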
template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == InputDataType;
                                     }),
                         "Trying to create workload with incorrect type");

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == OutputDataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};

// FirstInputTypedWorkload used to check the type of the first input
template <typename QueueDescriptor, armnn::DataType DataType>
class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        if (!info.m_InputTensorInfos.empty())
        {
            ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
                             "Trying to create workload with incorrect type");
        }

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == DataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};

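// Convenience aliases that bind the workload templates above to specific data types.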
template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
                                    armnn::DataType::Float16,
                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;

template <typename QueueDescriptor>
using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;

template <typename QueueDescriptor>
using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                         armnn::DataType::Float32,
                                                         armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                       armnn::DataType::QAsymmU8,
                                                       armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::BFloat16,
                                                     armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::Float32,
                                                     armnn::DataType::BFloat16>;

template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float16,
                                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float32,
                                                    armnn::DataType::Float16>;

template <typename QueueDescriptor>
using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                  armnn::DataType::QAsymmU8,
                                                  armnn::DataType::Float32>;

} //namespace armnn
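
For orientation, here is a minimal sketch of how a backend might build on these templates. The names are illustrative assumptions, not part of this header: SketchAdditionFloat32Workload is hypothetical, AdditionQueueDescriptor is assumed to come from WorkloadData.hpp, and the body of Execute() stands in for a real backend kernel.

// Hypothetical example (not part of this header): a Float32-only addition
// workload. Float32Workload<AdditionQueueDescriptor> expands to
// TypedWorkload<AdditionQueueDescriptor, DataType::Float32>, so the base
// constructor asserts that every input and output tensor is Float32.
class SketchAdditionFloat32Workload : public armnn::Float32Workload<armnn::AdditionQueueDescriptor>
{
public:
    SketchAdditionFloat32Workload(const armnn::AdditionQueueDescriptor& descriptor,
                                  const armnn::WorkloadInfo& info)
        : Float32Workload<armnn::AdditionQueueDescriptor>(descriptor, info)
    {}

    void Execute() const override
    {
        // A real backend would read the tensors in m_Data.m_Inputs and
        // write the elementwise sum into m_Data.m_Outputs here.
    }
};

Because the type check runs in the TypedWorkload constructor, passing a WorkloadInfo whose tensors are not Float32 trips the ARMNN_ASSERT_MSG at construction time rather than failing later during execution.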