ArmNN
 21.11
RefWorkloadFactory.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <Layer.hpp>
10 #include "RefWorkloadFactory.hpp"
11 #include "RefBackendId.hpp"
13 #include "RefTensorHandle.hpp"
14 
15 
16 namespace armnn
17 {
18 
19 namespace
20 {
21 static const BackendId s_Id{RefBackendId()};
22 }
23 template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
24 std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
25  const WorkloadInfo& info) const
26 {
27  return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload, NullWorkload>
28  (descriptor, info);
29 }
30 
31 template <DataType ArmnnType>
33 {
34  auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
35  auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
36  if (it != std::end(info.m_InputTensorInfos))
37  {
38  return true;
39  }
40  it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
41  if (it != std::end(info.m_OutputTensorInfos))
42  {
43  return true;
44  }
45  return false;
46 }
47 
49 {
50  return IsDataType<DataType::Signed32>(info);
51 }
52 
54 {
55  return IsDataType<DataType::BFloat16>(info);
56 }
57 
59 {
60  return IsDataType<DataType::Float16>(info);
61 }
62 
64 {
65  return IsDataType<DataType::QSymmS16>(info);
66 }
67 
69 {
70  return IsDataType<DataType::QSymmS8>(info);
71 }
72 
74 {
75  return IsDataType<DataType::QAsymmS8>(info);
76 }
77 
79 {
80  return IsDataType<DataType::QAsymmU8>(info);
81 }
82 
83 RefWorkloadFactory::RefWorkloadFactory(const std::shared_ptr<RefMemoryManager>& memoryManager)
84  : m_MemoryManager(memoryManager)
85 {
86 }
87 
89  : m_MemoryManager(new RefMemoryManager())
90 {
91 }
92 
94 {
95  return s_Id;
96 }
97 
99  Optional<DataType> dataType,
100  std::string& outReasonIfUnsupported)
101 {
102  return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
103 }
104 
106  Optional<DataType> dataType,
107  std::string& outReasonIfUnsupported,
108  const ModelOptions& modelOptions)
109 {
110  return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
111 }
112 
113 std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
114  const bool isMemoryManaged) const
115 {
116  if (isMemoryManaged)
117  {
118  return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager);
119  }
120  else
121  {
122  return std::make_unique<RefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
123  }
124 }
125 
126 std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
127  DataLayout dataLayout,
128  const bool isMemoryManaged) const
129 {
130  // For Ref it is okay to make the TensorHandle memory managed as it can also store a pointer
131  // to unmanaged memory. This also ensures memory alignment.
132  IgnoreUnused(isMemoryManaged, dataLayout);
133 
134  if (isMemoryManaged)
135  {
136  return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager);
137  }
138  else
139  {
140  return std::make_unique<RefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
141  }
142 }
143 
144 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
145  const WorkloadInfo& info) const
146 {
147  return std::make_unique<RefActivationWorkload>(descriptor, info);
148 }
149 
150 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
151  const WorkloadInfo& info) const
152 {
153  if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
154  {
155  return std::make_unique<RefAdditionWorkload<int32_t>>(descriptor, info);
156  }
157  else
158  {
159  return std::make_unique<RefAdditionWorkload<float>>(descriptor, info);
160  }
161 }
162 
163 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
164  const WorkloadInfo& info) const
165 {
166  return std::make_unique<RefArgMinMaxWorkload>(descriptor, info);
167 }
168 
170  const BatchNormalizationQueueDescriptor& descriptor,
171  const WorkloadInfo& info) const
172 {
173  return std::make_unique<RefBatchNormalizationWorkload>(descriptor, info);
174 }
175 
176 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
177  const WorkloadInfo& info) const
178 {
179  return std::make_unique<RefBatchToSpaceNdWorkload>(descriptor, info);
180 }
181 
182 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateCast(const CastQueueDescriptor& descriptor,
183  const WorkloadInfo& info) const
184 {
185  return std::make_unique<RefCastWorkload>(descriptor, info);
186 }
187 
188 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateChannelShuffle(const ChannelShuffleQueueDescriptor &descriptor,
189  const WorkloadInfo &info) const
190 {
191  return std::make_unique<RefChannelShuffleWorkload>(descriptor,info);
192 }
193 
194 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
195  const WorkloadInfo& info) const
196 {
197  return std::make_unique<RefComparisonWorkload>(descriptor, info);
198 }
199 
200 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
201  const WorkloadInfo& info) const
202 {
203  return std::make_unique<RefConcatWorkload>(descriptor, info);
204 }
205 
206 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
207  const WorkloadInfo& info) const
208 {
209  return std::make_unique<RefConstantWorkload>(descriptor, info);
210 }
211 
213  const ConvertBf16ToFp32QueueDescriptor& descriptor,
214  const WorkloadInfo& info) const
215 {
216  return std::make_unique<RefConvertBf16ToFp32Workload>(descriptor, info);
217 }
218 
220  const ConvertFp16ToFp32QueueDescriptor& descriptor,
221  const WorkloadInfo& info) const
222 {
223  return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info);
224 }
225 
227  const ConvertFp32ToBf16QueueDescriptor& descriptor,
228  const WorkloadInfo& info) const
229 {
230  return std::make_unique<RefConvertFp32ToBf16Workload>(descriptor, info);
231 }
232 
234  const ConvertFp32ToFp16QueueDescriptor& descriptor,
235  const WorkloadInfo& info) const
236 {
237  return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info);
238 }
239 
240 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
241  const WorkloadInfo& info) const
242 {
243  return std::make_unique<RefConvolution2dWorkload>(descriptor, info);
244 }
245 
246 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvolution3d(const Convolution3dQueueDescriptor& descriptor,
247  const WorkloadInfo& info) const
248 {
249  return std::make_unique<RefConvolution3dWorkload>(descriptor, info);
250 }
251 
252 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
253  const WorkloadInfo& info) const
254 {
255  if (IsBFloat16(info))
256  {
257  return std::make_unique<RefDebugBFloat16Workload>(descriptor, info);
258  }
259  if (IsFloat16(info))
260  {
261  return std::make_unique<RefDebugFloat16Workload>(descriptor, info);
262  }
263  if (IsQSymmS16(info))
264  {
265  return std::make_unique<RefDebugQSymmS16Workload>(descriptor, info);
266  }
267  if (IsQSymmS8(info))
268  {
269  return std::make_unique<RefDebugQSymmS8Workload>(descriptor, info);
270  }
271  if (IsQAsymmU8(info))
272  {
273  return std::make_unique<RefDebugQAsymmU8Workload>(descriptor, info);
274  }
275  if (IsQAsymmS8(info))
276  {
277  return std::make_unique<RefDebugQAsymmS8Workload>(descriptor, info);
278  }
279  if (IsSigned32(info))
280  {
281  return std::make_unique<RefDebugSigned32Workload>(descriptor, info);
282  }
283 
284  return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymmU8Workload>(descriptor, info);
285 }
286 
287 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
288  const WorkloadInfo& info) const
289 {
290  return std::make_unique<RefDepthToSpaceWorkload>(descriptor, info);
291 }
292 
294  const DepthwiseConvolution2dQueueDescriptor& descriptor,
295  const WorkloadInfo& info) const
296 {
297  return std::make_unique<RefDepthwiseConvolution2dWorkload>(descriptor, info);
298 }
299 
300 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor,
301  const WorkloadInfo& info) const
302 {
303  return std::make_unique<RefDequantizeWorkload>(descriptor, info);
304 }
305 
307  const DetectionPostProcessQueueDescriptor& descriptor,
308  const WorkloadInfo& info) const
309 {
310  return std::make_unique<RefDetectionPostProcessWorkload>(descriptor, info);
311 }
312 
313 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
314  const WorkloadInfo& info) const
315 {
316  if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
317  {
318  return std::make_unique<RefDivisionWorkload<int32_t>>(descriptor, info);
319  }
320  else
321  {
322  return std::make_unique<RefDivisionWorkload<float>>(descriptor, info);
323  }
324 }
325 
327  const WorkloadInfo& info) const
328 {
330  {
331  return std::make_unique<RefLogicalUnaryWorkload>(descriptor, info);
332  }
333  return std::make_unique<RefElementwiseUnaryWorkload>(descriptor, info);
334 }
335 
337  const WorkloadInfo& info) const
338 {
339  return MakeWorkload<RefFakeQuantizationFloat32Workload, NullWorkload>(descriptor, info);
340 }
341 
342 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFill(const FillQueueDescriptor& descriptor,
343  const WorkloadInfo& info) const
344 {
345  return std::make_unique<RefFillWorkload>(descriptor, info);
346 }
347 
348 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
349  const WorkloadInfo& info) const
350 {
351  if(IsQuantizedType(info.m_InputTensorInfos[0].GetDataType()))
352  {
353  return nullptr;
354  }
355  else
356  {
357  return std::make_unique<RefFloorWorkload>(descriptor, info);
358  }
359 }
360 
361 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFullyConnected(
362  const FullyConnectedQueueDescriptor& descriptor,
363  const WorkloadInfo& info) const
364 {
365  return std::make_unique<RefFullyConnectedWorkload>(descriptor, info);
366 }
367 
368 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
369  const WorkloadInfo& info) const
370 {
371  return std::make_unique<RefGatherWorkload>(descriptor, info);
372 }
373 
374 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
375  const WorkloadInfo& info) const
376 {
377  if (info.m_InputTensorInfos.empty() )
378  {
379  throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length");
380  }
381  if (info.m_OutputTensorInfos.empty())
382  {
383  throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length");
384  }
385 
386  if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
387  {
388  throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
389  }
390 
391  return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
392 }
393 
395  const InstanceNormalizationQueueDescriptor& descriptor,
396  const WorkloadInfo& info) const
397 {
398  return std::make_unique<RefInstanceNormalizationWorkload>(descriptor, info);
399 }
400 
402  const WorkloadInfo& info) const
403 {
404  return std::make_unique<RefL2NormalizationWorkload>(descriptor, info);
405 }
406 
407 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
408  const WorkloadInfo& info) const
409 {
410  return std::make_unique<RefLogicalBinaryWorkload>(descriptor, info);
411 }
412 
413 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
414  const WorkloadInfo& info) const
415 {
416  return std::make_unique<RefLogSoftmaxWorkload>(descriptor, info);
417 }
418 
419 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
420  const WorkloadInfo& info) const
421 {
422  return std::make_unique<RefLstmWorkload>(descriptor, info);
423 }
424 
425 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
426  const WorkloadInfo& info) const
427 {
428  if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
429  {
430  return std::make_unique<RefMaximumWorkload<int32_t>>(descriptor, info);
431  }
432  else
433  {
434  return std::make_unique<RefMaximumWorkload<float>>(descriptor, info);
435  }
436 }
437 
438 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
439  const WorkloadInfo& info) const
440 {
441  return std::make_unique<RefMeanWorkload>(descriptor, info);
442 }
443 
444 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
445  const WorkloadInfo& info) const
446 {
447  if (descriptor.m_Inputs.empty())
448  {
449  throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor.");
450  }
451  return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
452 }
453 
454 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
455  const WorkloadInfo& info) const
456 {
457  if (descriptor.m_Inputs.empty())
458  {
459  throw InvalidArgumentException("RefWorkloadFactory: CreateMemImport() expected an input tensor.");
460  }
461  return std::make_unique<ImportMemGenericWorkload>(descriptor, info);
462 }
463 
464 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
465  const WorkloadInfo& info) const
466 {
467  if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
468  {
469  return std::make_unique<RefMinimumWorkload<int32_t>>(descriptor, info);
470  }
471  else
472  {
473  return std::make_unique<RefMinimumWorkload<float>>(descriptor, info);
474  }
475 }
476 
477 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
478  const WorkloadInfo& info) const
479 {
480  if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
481  {
482  return std::make_unique<RefMultiplicationWorkload<int32_t>>(descriptor, info);
483  }
484  else
485  {
486  return std::make_unique<RefMultiplicationWorkload<float>>(descriptor, info);
487  }
488 }
489 
490 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
491  const WorkloadInfo& info) const
492 {
493  return std::make_unique<RefNormalizationWorkload>(descriptor, info);
494 }
495 
496 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
497  const WorkloadInfo& info) const
498 {
499  if (info.m_InputTensorInfos.empty() )
500  {
501  throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length");
502  }
503  if (info.m_OutputTensorInfos.empty())
504  {
505  throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length");
506  }
507  if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
508  {
509  throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
510  }
511 
512  return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
513 }
514 
515 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
516  const WorkloadInfo& info) const
517 {
518  return std::make_unique<RefPadWorkload>(descriptor, info);
519 }
520 
521 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
522  const WorkloadInfo& info) const
523 {
524  if (IsQSymmS16(info))
525  {
526  return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info);
527  }
528  else if (IsBFloat16(info))
529  {
530  return std::make_unique<RefPermuteBFloat16Workload>(descriptor, info);
531  }
532  else if (IsQAsymmS8(info))
533  {
534  return std::make_unique<RefPermuteQAsymmS8Workload>(descriptor, info);
535  }
537  NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
538 }
539 
540 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
541  const WorkloadInfo& info) const
542 {
543  return std::make_unique<RefPooling2dWorkload>(descriptor, info);
544 }
545 
546 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
547  const WorkloadInfo& /*info*/) const
548 {
549  return nullptr;
550 }
551 
552 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePrelu(const PreluQueueDescriptor& descriptor,
553  const WorkloadInfo& info) const
554 {
555  return std::make_unique<RefPreluWorkload>(descriptor, info);
556 }
557 
558 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& descriptor,
559  const WorkloadInfo& info) const
560 {
561  return std::make_unique<RefQLstmWorkload>(descriptor, info);
562 }
563 
564 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
565  const WorkloadInfo& info) const
566 {
567  return std::make_unique<RefQuantizeWorkload>(descriptor, info);
568 }
569 
570 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRank(const RankQueueDescriptor& descriptor,
571  const WorkloadInfo& info) const
572 {
573  return std::make_unique<RefRankWorkload>(descriptor, info);
574 }
575 
576 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& descriptor,
577  const WorkloadInfo& info) const
578 {
579  return std::make_unique<RefReduceWorkload>(descriptor, info);
580 }
581 
582 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
583  const WorkloadInfo& info) const
584 {
585  return std::make_unique<RefReshapeWorkload>(descriptor, info);
586 }
587 
588 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
589  const WorkloadInfo& info) const
590 {
591  return std::make_unique<RefResizeWorkload>(descriptor, info);
592 }
593 
594 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateShape(const ShapeQueueDescriptor& descriptor,
595  const WorkloadInfo& info) const
596 {
597  return std::make_unique<RefShapeWorkload>(descriptor, info);
598 }
599 
600 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor,
601  const WorkloadInfo& info) const
602 {
603  return std::make_unique<RefSliceWorkload>(descriptor, info);
604 }
605 
606 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
607  const WorkloadInfo& info) const
608 {
609  return std::make_unique<RefSoftmaxWorkload>(descriptor, info);
610 }
611 
612 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
613  const WorkloadInfo& info) const
614 {
615  return std::make_unique<RefSpaceToBatchNdWorkload>(descriptor, info);
616 }
617 
618 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
619  const WorkloadInfo& info) const
620 {
621  return std::make_unique<RefSpaceToDepthWorkload>(descriptor, info);
622 }
623 
624 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
625  const WorkloadInfo& info) const
626 {
627  return std::make_unique<RefSplitterWorkload>(descriptor, info);
628 }
629 
630 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
631  const WorkloadInfo& info) const
632 {
633  return std::make_unique<RefStackWorkload>(descriptor, info);
634 }
635 
636 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
637  const WorkloadInfo& info) const
638 {
639  return std::make_unique<RefStridedSliceWorkload>(descriptor, info);
640 }
641 
642 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
643  const WorkloadInfo& info) const
644 {
645  if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
646  {
647  return std::make_unique<RefSubtractionWorkload<int32_t>>(descriptor, info);
648  }
649  else
650  {
651  return std::make_unique<RefSubtractionWorkload<float>>(descriptor, info);
652  }
653 }
654 
655 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,
656  const WorkloadInfo& info) const
657 {
658  if (IsQSymmS16(info))
659  {
660  return std::make_unique<RefTransposeQSymm16Workload>(descriptor, info);
661  }
662  else if (IsBFloat16(info))
663  {
664  return std::make_unique<RefTransposeBFloat16Workload>(descriptor, info);
665  }
666  else if (IsQAsymmS8(info))
667  {
668  return std::make_unique<RefTransposeQAsymmS8Workload>(descriptor, info);
669  }
671  NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
672 }
673 
675  const TransposeConvolution2dQueueDescriptor& descriptor,
676  const WorkloadInfo& info) const
677 {
678  return std::make_unique<RefTransposeConvolution2dWorkload>(descriptor, info);
679 }
680 
683  const WorkloadInfo& info) const
684 {
685  return std::make_unique<RefUnidirectionalSequenceLstmWorkload>(descriptor, info);;
686 }
687 
688 } // namespace armnn
std::unique_ptr< IWorkload > CreateMemCopy(const MemCopyQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateNormalization(const NormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateArgMinMax(const ArgMinMaxQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateReshape(const ReshapeQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateResize(const ResizeQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateFakeQuantization(const FakeQuantizationQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateConvolution2d(const Convolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const override
UnaryOperation m_Operation
Specifies the elementwiseUnary operation to execute.
std::unique_ptr< IWorkload > CreateConstant(const ConstantQueueDescriptor &descriptor, const WorkloadInfo &info) const override
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition: INetwork.hpp:61
std::unique_ptr< IWorkload > CreateInput(const InputQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateMaximum(const MaximumQueueDescriptor &descriptor, const WorkloadInfo &info) const override
DataLayout
Definition: Types.hpp:49
std::unique_ptr< IWorkload > CreateGather(const GatherQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateQuantize(const QuantizeQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateUnidirectionalSequenceLstm(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const override
constexpr bool IsQuantizedType()
Definition: TypesUtils.hpp:280
std::unique_ptr< IWorkload > CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateStridedSlice(const StridedSliceQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateSoftmax(const SoftmaxQueueDescriptor &descriptor, const WorkloadInfo &info) const override
constexpr const char * RefBackendId()
std::vector< BackendOptions > ModelOptions
std::unique_ptr< IWorkload > CreateMultiplication(const MultiplicationQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateSpaceToDepth(const SpaceToDepthQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateChannelShuffle(const ChannelShuffleQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreatePad(const PadQueueDescriptor &descriptor, const WorkloadInfo &info) const override
RefPermuteWorkload< DataType::Float16 > RefPermuteFloat16Workload
RefTransposeWorkload< DataType::Float16 > RefTransposeFloat16Workload
std::unique_ptr< IWorkload > CreateComparison(const ComparisonQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor &descriptor, const WorkloadInfo &info) const override
bool IsQAsymmS8(const WorkloadInfo &info)
std::unique_ptr< IWorkload > CreateLogSoftmax(const LogSoftmaxQueueDescriptor &descriptor, const WorkloadInfo &info) const override
RefPermuteWorkload< DataType::Float32 > RefPermuteFloat32Workload
std::unique_ptr< IWorkload > CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor &descriptor, const WorkloadInfo &info) const override
Copyright (c) 2021 ARM Limited and Contributors.
std::unique_ptr< IWorkload > CreateRank(const RankQueueDescriptor &descriptor, const WorkloadInfo &info) const override
void IgnoreUnused(Ts &&...)
bool IsQAsymmU8(const WorkloadInfo &info)
bool IsQSymmS8(const WorkloadInfo &info)
bool IsDataType(const WorkloadInfo &info)
bool IsBFloat16(const WorkloadInfo &info)
std::unique_ptr< IWorkload > CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateFill(const FillQueueDescriptor &descriptor, const WorkloadInfo &info) const override
const BackendId & GetBackendId() const override
std::unique_ptr< IWorkload > CreateBatchNormalization(const BatchNormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateL2Normalization(const L2NormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateLstm(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateConcat(const ConcatQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::vector< TensorInfo > m_InputTensorInfos
bool IsFloat16(const WorkloadInfo &info)
std::unique_ptr< IWorkload > CreateSubtraction(const SubtractionQueueDescriptor &descriptor, const WorkloadInfo &info) const override
RefTransposeWorkload< DataType::Float32 > RefTransposeFloat32Workload
std::unique_ptr< IWorkload > CreateReduce(const ReduceQueueDescriptor &descriptor, const WorkloadInfo &info) const override
RefTransposeWorkload< DataType::QAsymmU8 > RefTransposeQAsymm8Workload
static bool IsLayerSupported(const Layer &layer, Optional< DataType > dataType, std::string &outReasonIfUnsupported)
std::unique_ptr< IWorkload > CreateDebug(const DebugQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateAddition(const AdditionQueueDescriptor &descriptor, const WorkloadInfo &info) const override
bool IsSigned32(const WorkloadInfo &info)
std::vector< TensorInfo > m_OutputTensorInfos
static bool IsLayerSupported(const BackendId &backendId, const IConnectableLayer &layer, Optional< DataType > dataType, std::string &outReasonIfUnsupported)
std::unique_ptr< IWorkload > CreatePooling2d(const Pooling2dQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreatePrelu(const PreluQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateSplitter(const SplitterQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateDequantize(const DequantizeQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateFloor(const FloorQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateConvolution3d(const Convolution3dQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateSlice(const SliceQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateFullyConnected(const FullyConnectedQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateMean(const MeanQueueDescriptor &descriptor, const WorkloadInfo &Info) const override
std::unique_ptr< IWorkload > CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateStack(const StackQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateDivision(const DivisionQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateShape(const ShapeQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo, const bool IsMemoryManaged=true) const override
std::unique_ptr< IWorkload > CreateLogicalBinary(const LogicalBinaryQueueDescriptor &descriptor, const WorkloadInfo &info) const override
Contains information about TensorInfos of a layer.
std::unique_ptr< IWorkload > CreatePermute(const PermuteQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateActivation(const ActivationQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateCast(const CastQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::vector< ITensorHandle * > m_Inputs
RefPermuteWorkload< DataType::QAsymmU8 > RefPermuteQAsymm8Workload
std::unique_ptr< IWorkload > CreateMinimum(const MinimumQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateOutput(const OutputQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreatePreCompiled(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateTranspose(const TransposeQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor &descriptor, const WorkloadInfo &info) const override
bool IsQSymmS16(const WorkloadInfo &info)
std::unique_ptr< IWorkload > CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const override
Depthwise Convolution 2D layer workload data.
std::unique_ptr< IWorkload > CreateQLstm(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateMemImport(const MemImportQueueDescriptor &descriptor, const WorkloadInfo &info) const override
std::unique_ptr< IWorkload > CreateDepthToSpace(const DepthToSpaceQueueDescriptor &descriptor, const WorkloadInfo &info) const override