28 template <
typename QueueDescriptor>
37 m_Data.Validate(info);
42 ARMNN_LOG(
info) <<
"Using default async workload execution, this will network affect performance";
43 std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
45 m_Data.m_Inputs = workingMemDescriptor.
m_Inputs;
46 m_Data.m_Outputs = workingMemDescriptor.
m_Outputs;
/// Returns the profiling GUID held in m_Guid for this workload.
/// Declared `final`, so classes deriving from this one cannot override it.
55 profiling::ProfilingGuid
GetGuid() const final {
return m_Guid; }
/// Profiling GUID handed back by GetGuid(). Declared const, so it cannot be
/// reassigned once the object has been constructed.
59 const profiling::ProfilingGuid
m_Guid;
/// Mutex acquired through std::lock_guard before m_Data.m_Inputs and
/// m_Data.m_Outputs are overwritten from a WorkingMemDescriptor.
/// NOTE(review): the locking code appears to belong to the default
/// ExecuteAsync path; confirm against the full header.
62 std::mutex m_AsyncWorkloadMutex;
74 std::vector<armnn::DataType> dataTypes = {DataTypes...};
81 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
88 return it.GetDataType() == expectedInputType;
90 "Trying to create workload with incorrect type");
100 if (expectedOutputType != expectedInputType)
105 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
112 return it.GetDataType() == expectedOutputType;
114 "Trying to create workload with incorrect type");
119 template <
typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
130 return it.GetDataType() == InputDataType;
132 "Trying to create workload with incorrect type");
137 return it.GetDataType() == OutputDataType;
139 "Trying to create workload with incorrect type");
144 template <
typename QueueDescriptor, armnn::DataType DataType>
155 "Trying to create workload with incorrect type");
161 return it.GetDataType() ==
DataType;
163 "Trying to create workload with incorrect type");
167 template <
typename QueueDescriptor>
172 template <
typename QueueDescriptor>
175 template <
typename QueueDescriptor>
178 template <
typename QueueDescriptor>
181 template <
typename QueueDescriptor>
184 template <
typename QueueDescriptor>
189 template <
typename QueueDescriptor>
194 template <
typename QueueDescriptor>
197 armnn::DataType::Float32>;
199 template <
typename QueueDescriptor>
202 armnn::DataType::BFloat16>;
204 template <
typename QueueDescriptor>
207 armnn::DataType::Float32>;
209 template <
typename QueueDescriptor>
212 armnn::DataType::Float16>;
214 template <
typename QueueDescriptor>
217 armnn::DataType::Float32>;
#define ARMNN_LOG(severity)
MultiTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Copyright (c) 2021 ARM Limited and Contributors.
void ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) override
const profiling::ProfilingGuid m_Guid
std::vector< ITensorHandle * > m_Inputs
BaseWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
std::vector< TensorInfo > m_InputTensorInfos
const QueueDescriptor & GetData() const
#define ARMNN_ASSERT_MSG(COND, MSG)
void PostAllocationConfigure() override
std::vector< TensorInfo > m_OutputTensorInfos
Workload interface to enqueue a layer computation.
profiling::ProfilingGuid GetGuid() const final
TypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
std::vector< ITensorHandle * > m_Outputs
Contains information about TensorInfos of a layer.
virtual void Execute() const =0