29 // Alias for a list of raw Layer pointers (the list does not own the layers).
using LayerList = std::list<Layer*>;
30 // Read-only iterator over a LayerList.
using Iterator = LayerList::const_iterator;
32 const TensorInfo OverrideDataType(
const TensorInfo& info, Optional<DataType> type)
39 return TensorInfo(info.GetShape(),
41 info.GetQuantizationScale(),
42 info.GetQuantizationOffset(),
48 bool IWorkloadFactory::IsLayerConfigurationSupported(
const BackendId& backendId,
49 const IConnectableLayer& connectableLayer,
50 Optional<DataType> dataType,
51 std::string& outReasonIfUnsupported,
54 Optional<std::string&> reason = outReasonIfUnsupported;
56 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
59 if (!backendRegistry.IsBackendRegistered(backendId))
62 ss << connectableLayer.GetName() <<
" is not supported on " << backendId
63 <<
" because this backend is not registered.";
65 outReasonIfUnsupported = ss.str();
69 auto backendFactory = backendRegistry.GetFactory(backendId);
70 auto backendObject = backendFactory();
71 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
73 switch(layer.GetType())
77 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
78 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
79 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
80 result = layerSupportObject.IsActivationSupported(
81 OverrideDataType(input, dataType),
82 OverrideDataType(output, dataType),
83 cLayer->GetParameters(),
89 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
90 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
91 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
92 result = layerSupportObject.IsAdditionSupported(
93 OverrideDataType(input0, dataType),
94 OverrideDataType(input1, dataType),
95 OverrideDataType(output, dataType),
101 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
102 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
104 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
106 result = layerSupportObject.IsArgMinMaxSupported(
107 OverrideDataType(input, dataType),
115 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
122 result = layerSupportObject.IsBatchNormalizationSupported(
123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
137 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
139 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
147 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
148 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
150 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
151 OverrideDataType(output, dataType),
157 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
159 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
162 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
164 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
165 OverrideDataType(output, dataType),
172 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
174 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
175 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
176 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
178 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
179 OverrideDataType(input1, dataType),
181 cLayer->GetParameters(),
187 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
188 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
193 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
194 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
195 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
200 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
201 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
202 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
207 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
208 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
209 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
214 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
215 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
216 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
221 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
223 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
225 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
228 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
231 Optional<TensorInfo> biases;
232 if (descriptor.m_BiasEnabled)
237 result = layerSupportObject.IsConvolution2dSupported(
241 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
248 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
250 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
252 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
255 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
256 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
259 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
262 Optional<TensorInfo> biases;
263 if (descriptor.m_BiasEnabled)
265 biases = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
269 result = layerSupportObject.IsConvolution3dSupported(
280 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
281 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
283 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
284 OverrideDataType(output, dataType),
290 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
292 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
293 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
295 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
296 OverrideDataType(output, dataType),
297 cLayer->GetParameters(),
303 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
304 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
306 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
309 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
312 Optional<TensorInfo> biases;
313 if (descriptor.m_BiasEnabled)
319 result = layerSupportObject.IsDepthwiseConvolutionSupported(
323 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
330 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
331 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
333 result = layerSupportObject.IsDequantizeSupported(input,
334 OverrideDataType(output, dataType),
340 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
341 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
342 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
343 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
345 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
346 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
347 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
348 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
350 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
351 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
364 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
366 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
367 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
369 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
370 OverrideDataType(output, dataType),
371 cLayer->GetParameters(),
377 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
378 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
379 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
380 const FillDescriptor& descriptor = cLayer->GetParameters();
382 result = layerSupportObject.IsFillSupported(
383 OverrideDataType(input, dataType),
384 OverrideDataType(output, dataType),
391 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
392 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
393 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
394 cLayer->GetParameters(),
400 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
401 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
402 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
403 OverrideDataType(output, dataType),
409 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
410 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
411 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
413 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
414 TensorInfo weightsInfo;
415 const TensorInfo* weightsInfoPtr =
nullptr;
417 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
418 weightsInfoPtr = &weightsInfo;
421 const TensorInfo* biasInfoPtr =
nullptr;
422 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}),
DataType::BFloat16);
423 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}),
DataType::Float16);
424 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}),
DataType::Float32);
427 if (descriptor.m_BiasEnabled)
429 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
430 biasInfoPtr = &biasInfo;
435 switch(input.GetDataType())
439 biasInfoPtr = &dummyBFloat16Bias;
444 biasInfoPtr = &dummyFloat16Bias;
449 biasInfoPtr = &dummyFloat32Bias;
457 biasInfoPtr = &dummyQA8Bias;
466 result = layerSupportObject.IsFullyConnectedSupported(
467 OverrideDataType(input, dataType),
468 OverrideDataType(output, dataType),
477 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
478 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
479 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
480 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
481 const GatherDescriptor& descriptor = cLayer->GetParameters();
482 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
484 OverrideDataType(output, dataType),
491 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
492 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
497 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
498 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
500 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
501 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
503 result = layerSupportObject.IsInstanceNormalizationSupported(
504 OverrideDataType(input, dataType),
505 OverrideDataType(output, dataType),
512 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
513 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
515 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
516 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
518 result = layerSupportObject.IsL2NormalizationSupported(
519 OverrideDataType(input, dataType),
520 OverrideDataType(output, dataType),
527 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
529 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
530 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
531 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
533 result = layerSupportObject.IsLogicalBinarySupported(input0,
536 cLayer->GetParameters(),
542 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
544 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
545 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
547 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
548 OverrideDataType(output, dataType),
549 cLayer->GetParameters(),
555 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
556 const LstmDescriptor& descriptor = cLayer->GetParameters();
559 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
561 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
563 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
566 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
567 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
568 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
569 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
572 const TensorInfo& inputToForgetWeights
573 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
574 const TensorInfo& inputToCellWeights
575 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
576 const TensorInfo& inputToOutputWeights
577 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
578 const TensorInfo& recurrentToForgetWeights
579 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
580 const TensorInfo& recurrentToCellWeights
581 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
582 const TensorInfo& recurrentToOutputWeights
583 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
584 const TensorInfo& forgetGateBias
585 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
586 const TensorInfo& cellBias
587 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
588 const TensorInfo& outputGateBias
589 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
591 LstmInputParamsInfo paramsInfo;
593 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
594 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
595 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
596 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
597 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
598 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
599 paramsInfo.m_ForgetGateBias = &forgetGateBias;
600 paramsInfo.m_CellBias = &cellBias;
601 paramsInfo.m_OutputGateBias = &outputGateBias;
605 TensorInfo optInputToInputWeights;
606 TensorInfo optRecurrentToInputWeights;
607 TensorInfo optCellToInputWeights;
608 TensorInfo optInputGateBias;
609 TensorInfo optProjectionWeights;
610 TensorInfo optProjectionBias;
611 TensorInfo optCellToForgetWeights;
612 TensorInfo optCellToOutputWeights;
613 TensorInfo optInputLayerNormWeights;
614 TensorInfo optForgetLayerNormWeights;
615 TensorInfo optCellLayerNormWeights;
616 TensorInfo optOutputLayerNormWeights;
618 if(!descriptor.m_CifgEnabled)
620 optInputToInputWeights =
621 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
622 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
624 optRecurrentToInputWeights =
625 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
626 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
628 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
629 paramsInfo.m_InputGateBias = &optInputGateBias;
632 if(descriptor.m_ProjectionEnabled)
634 optProjectionWeights =
635 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
636 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
637 if (cLayer->m_ProjectionParameters.m_ProjectionBias !=
nullptr)
640 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
641 paramsInfo.m_ProjectionBias = &optProjectionBias;
645 if(descriptor.m_PeepholeEnabled)
647 if(!descriptor.m_CifgEnabled)
649 optCellToInputWeights =
650 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
652 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
654 optCellToForgetWeights =
655 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
656 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
657 optCellToOutputWeights =
658 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
659 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
662 if(descriptor.m_LayerNormEnabled)
664 if (!descriptor.m_CifgEnabled)
666 optInputLayerNormWeights = OverrideDataType(
667 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
668 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
671 optForgetLayerNormWeights = OverrideDataType(
672 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
673 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
675 optCellLayerNormWeights = OverrideDataType(
676 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
677 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
679 optOutputLayerNormWeights = OverrideDataType(
680 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
681 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
684 result = layerSupportObject.IsLstmSupported(
699 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
700 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
701 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
703 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
704 OverrideDataType(input1, dataType),
705 OverrideDataType(output, dataType),
711 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
712 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
714 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
715 OverrideDataType(output, dataType),
721 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
722 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
724 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
725 OverrideDataType(output, dataType),
731 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
732 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
733 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
735 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
736 OverrideDataType(input1, dataType),
737 OverrideDataType(output, dataType),
743 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
746 auto getTensorInfo = [&dataType](
const InputSlot& slot)
748 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
753 std::vector<TensorInfo> inputs(beginI, endI);
755 auto getTensorInfoPtr = [](
const TensorInfo&
info)
762 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
764 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
766 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
773 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
774 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
775 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
776 result = layerSupportObject.IsMultiplicationSupported(
777 OverrideDataType(input0, dataType),
778 OverrideDataType(input1, dataType),
779 OverrideDataType(output, dataType),
785 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
786 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
787 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
788 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
789 OverrideDataType(output, dataType),
790 cLayer->GetParameters(),
796 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
797 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
802 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
803 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
804 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
805 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
806 OverrideDataType(output, dataType),
807 cLayer->GetParameters(),
813 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
814 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
815 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
816 result = layerSupportObject.IsPadSupported(
817 OverrideDataType(input, dataType),
818 OverrideDataType(output, dataType),
819 cLayer->GetParameters(),
825 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
826 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
827 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
828 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
829 OverrideDataType(output, dataType),
830 cLayer->GetParameters(),
836 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
837 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
838 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
839 cLayer->GetParameters(),
845 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
846 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
847 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
852 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
853 const QLstmDescriptor& descriptor = cLayer->GetParameters();
856 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
857 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
858 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
861 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
862 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
863 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
866 LstmInputParamsInfo paramsInfo;
869 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() !=
nullptr);
870 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() !=
nullptr);
871 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() !=
nullptr);
872 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
873 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
874 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
876 paramsInfo.m_RecurrentToForgetWeights =
877 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
878 paramsInfo.m_RecurrentToCellWeights =
879 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
880 paramsInfo.m_RecurrentToOutputWeights =
881 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
883 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
884 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
885 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
887 if(!descriptor.m_CifgEnabled)
889 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
890 paramsInfo.m_RecurrentToInputWeights =
891 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
892 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
895 if(descriptor.m_ProjectionEnabled)
897 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
900 if (cLayer->m_ProjectionParameters.m_ProjectionBias !=
nullptr)
902 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
906 if(descriptor.m_PeepholeEnabled)
908 if (!descriptor.m_CifgEnabled)
910 paramsInfo.m_CellToInputWeights =
911 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
914 paramsInfo.m_CellToForgetWeights =
915 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
916 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
919 if(descriptor.m_LayerNormEnabled)
921 if (!descriptor.m_CifgEnabled)
923 paramsInfo.m_InputLayerNormWeights =
924 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
927 paramsInfo.m_ForgetLayerNormWeights =
928 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
929 paramsInfo.m_CellLayerNormWeights =
930 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
931 paramsInfo.m_OutputLayerNormWeights =
932 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
935 result = layerSupportObject.IsQLstmSupported(input,
948 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
951 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
952 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
953 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
956 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
957 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
960 QuantizedLstmInputParamsInfo paramsInfo;
962 paramsInfo.m_InputToInputWeights =
963 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
964 paramsInfo.m_InputToForgetWeights =
965 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
966 paramsInfo.m_InputToCellWeights =
967 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
968 paramsInfo.m_InputToOutputWeights =
969 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
971 paramsInfo.m_RecurrentToInputWeights =
972 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
973 paramsInfo.m_RecurrentToForgetWeights =
974 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
975 paramsInfo.m_RecurrentToCellWeights =
976 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
977 paramsInfo.m_RecurrentToOutputWeights =
978 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
980 paramsInfo.m_InputGateBias =
981 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
982 paramsInfo.m_ForgetGateBias =
983 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
984 paramsInfo.m_CellBias =
985 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
986 paramsInfo.m_OutputGateBias =
987 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
989 result = layerSupportObject.IsQuantizedLstmSupported(input,
1000 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1001 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1002 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1003 result = layerSupportObject.IsDivisionSupported(
1004 OverrideDataType(input0, dataType),
1005 OverrideDataType(input1, dataType),
1006 OverrideDataType(output, dataType),
1012 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1013 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1014 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
1015 OverrideDataType(output, dataType),
1021 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
1022 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1023 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1024 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
1025 OverrideDataType(output, dataType),
1026 cLayer->GetParameters(),
1032 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
1033 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1034 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1035 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1036 OverrideDataType(output, dataType),
1037 cLayer->GetParameters(),
1043 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1044 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1046 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1047 OverrideDataType(output, dataType),
1053 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
1055 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1056 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1058 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1059 OverrideDataType(output, dataType),
1060 cLayer->GetParameters(),
1066 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
1067 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1068 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1069 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1070 OverrideDataType(output, dataType),
1071 cLayer->GetParameters(),
1077 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
1078 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1079 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1080 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1081 OverrideDataType(output, dataType),
1082 cLayer->GetParameters(),
1088 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
1090 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1091 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1093 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1094 OverrideDataType(output, dataType),
1095 cLayer->GetParameters(),
1101 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
1102 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1105 auto getTensorInfo = [&dataType](
const OutputSlot& slot)
1107 return OverrideDataType(slot.GetTensorInfo(), dataType);
1111 std::vector<TensorInfo> outputs(beginI, endI);
1113 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1115 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1117 cLayer->GetParameters(),
1123 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
1126 auto getTensorInfo = [&dataType](
const InputSlot& slot)
1128 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1132 std::vector<TensorInfo> inputs(beginI, endI);
1134 auto getTensorInfoPtr = [](
const TensorInfo&
info)
1140 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1142 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1144 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
1150 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
1153 auto getTensorInfoIn = [&dataType](
const InputSlot& slot)
1155 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1157 auto getTensorInfoOut = [&dataType](
const OutputSlot& slot)
1159 return OverrideDataType(slot.GetTensorInfo(), dataType);
1163 std::vector<TensorInfo> inputs(beginI, endI);
1167 std::vector<TensorInfo> outputs(beginO, endO);
1170 auto getTensorInfoPtr = [](
const TensorInfo&
info)
1176 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1180 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1183 result = layerSupportObject.IsStandInSupported(inputPtrs,
1185 cLayer->GetParameters(),
1191 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
1192 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1193 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1194 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1195 OverrideDataType(output, dataType),
1196 cLayer->GetParameters(),
1202 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1203 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1204 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1205 result = layerSupportObject.IsSubtractionSupported(
1206 OverrideDataType(input0, dataType),
1207 OverrideDataType(input1, dataType),
1208 OverrideDataType(output, dataType),
1214 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1215 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1216 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1217 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
1218 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1219 OverrideDataType(input1, dataType),
1220 OverrideDataType(output0, dataType),
1221 OverrideDataType(output1, dataType),
1227 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
1228 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1229 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1230 result = layerSupportObject.IsMeanSupported(
1231 OverrideDataType(input, dataType),
1232 OverrideDataType(output, dataType),
1233 cLayer->GetParameters(),
1239 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1240 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1241 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1242 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1243 OverrideDataType(input1, dataType),
1244 OverrideDataType(output, dataType),
1250 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1251 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1252 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1253 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1254 OverrideDataType(alpha, dataType),
1255 OverrideDataType(output, dataType),
1261 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
1262 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1263 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1264 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1265 OverrideDataType(output, dataType),
1266 cLayer->GetParameters(),
1272 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
1274 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1276 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1278 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1280 Optional<TensorInfo> biases;
1281 if (descriptor.m_BiasEnabled)
1284 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1289 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1291 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1302 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1303 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1304 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1306 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1307 OverrideDataType(output, dataType),
1308 cLayer->GetParameters(),
1314 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1318 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1320 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1322 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1325 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1328 const TensorInfo& inputToForgetWeights
1329 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1330 const TensorInfo& inputToCellWeights
1331 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1332 const TensorInfo& inputToOutputWeights
1333 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1334 const TensorInfo& recurrentToForgetWeights
1335 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1336 const TensorInfo& recurrentToCellWeights
1337 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1338 const TensorInfo& recurrentToOutputWeights
1339 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1340 const TensorInfo& forgetGateBias
1341 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1342 const TensorInfo& cellBias
1343 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1344 const TensorInfo& outputGateBias
1345 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1347 LstmInputParamsInfo paramsInfo;
1349 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1350 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1351 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1352 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1353 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1354 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1355 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1356 paramsInfo.m_CellBias = &cellBias;
1357 paramsInfo.m_OutputGateBias = &outputGateBias;
1360 TensorInfo optInputToInputWeights;
1361 TensorInfo optRecurrentToInputWeights;
1362 TensorInfo optCellToInputWeights;
1363 TensorInfo optInputGateBias;
1364 TensorInfo optProjectionWeights;
1365 TensorInfo optProjectionBias;
1366 TensorInfo optCellToForgetWeights;
1367 TensorInfo optCellToOutputWeights;
1368 TensorInfo optInputLayerNormWeights;
1369 TensorInfo optForgetLayerNormWeights;
1370 TensorInfo optCellLayerNormWeights;
1371 TensorInfo optOutputLayerNormWeights;
1373 if(!descriptor.m_CifgEnabled)
1375 optInputToInputWeights =
1376 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1377 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1379 optRecurrentToInputWeights =
1380 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1381 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1383 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1384 paramsInfo.m_InputGateBias = &optInputGateBias;
1387 if(descriptor.m_ProjectionEnabled)
1389 optProjectionWeights =
1390 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1391 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1392 if (cLayer->m_ProjectionParameters.m_ProjectionBias !=
nullptr)
1395 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1396 paramsInfo.m_ProjectionBias = &optProjectionBias;
1400 if(descriptor.m_PeepholeEnabled)
1402 if(!descriptor.m_CifgEnabled)
1404 optCellToInputWeights =
1405 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1407 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1409 optCellToForgetWeights =
1410 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1411 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1412 optCellToOutputWeights =
1413 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1414 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1417 if(descriptor.m_LayerNormEnabled)
1419 if (!descriptor.m_CifgEnabled)
1421 optInputLayerNormWeights = OverrideDataType(
1422 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1423 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1426 optForgetLayerNormWeights = OverrideDataType(
1427 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1428 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1430 optCellLayerNormWeights = OverrideDataType(
1431 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1432 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1434 optOutputLayerNormWeights = OverrideDataType(
1435 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1436 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1439 Optional<TensorInfo> hiddenStateOut;
1440 Optional<TensorInfo> cellStateOut;
1442 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1455 ARMNN_ASSERT_MSG(
false,
"WorkloadFactory did not recognise type of layer.");
1456 reason.value() =
"Unrecognised layer type";
1467 std::string& outReasonIfUnsupported)
1469 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1474 std::string& outReasonIfUnsupported)
1476 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1477 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1483 std::string& outReasonIfUnsupported,
1486 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1487 return IsLayerConfigurationSupported(layer->GetBackendId(),
1490 outReasonIfUnsupported,
1497 std::string& outReasonIfUnsupported,
1500 return IsLayerConfigurationSupported(backendId,
1503 outReasonIfUnsupported,
1510 return std::unique_ptr<IWorkload>();
1516 return std::unique_ptr<IWorkload>();
1522 return std::unique_ptr<IWorkload>();
1528 return std::unique_ptr<IWorkload>();
1534 return std::unique_ptr<IWorkload>();
1540 return std::unique_ptr<IWorkload>();
1546 return std::unique_ptr<IWorkload>();
1552 return std::unique_ptr<IWorkload>();
1558 return std::unique_ptr<IWorkload>();
1564 return std::unique_ptr<IWorkload>();
1570 return std::unique_ptr<IWorkload>();
1576 return std::unique_ptr<IWorkload>();
1582 return std::unique_ptr<IWorkload>();
1588 return std::unique_ptr<IWorkload>();
1594 return std::unique_ptr<IWorkload>();
1600 return std::unique_ptr<IWorkload>();
1606 return std::unique_ptr<IWorkload>();
1612 return std::unique_ptr<IWorkload>();
1618 return std::unique_ptr<IWorkload>();
1624 return std::unique_ptr<IWorkload>();
1630 return std::unique_ptr<IWorkload>();
1636 return std::unique_ptr<IWorkload>();
1642 return std::unique_ptr<IWorkload>();
1648 return std::unique_ptr<IWorkload>();
1654 return std::unique_ptr<IWorkload>();
1660 return std::unique_ptr<IWorkload>();
1666 return std::unique_ptr<IWorkload>();
1672 return std::unique_ptr<IWorkload>();
1679 return std::unique_ptr<IWorkload>();
1685 return std::unique_ptr<IWorkload>();
1691 return std::unique_ptr<IWorkload>();
1697 return std::unique_ptr<IWorkload>();
1703 return std::unique_ptr<IWorkload>();
1709 return std::unique_ptr<IWorkload>();
1715 return std::unique_ptr<IWorkload>();
1721 return std::unique_ptr<IWorkload>();
1727 return std::unique_ptr<IWorkload>();
1733 return std::unique_ptr<IWorkload>();
1739 return std::unique_ptr<IWorkload>();
1745 return std::unique_ptr<IWorkload>();
1751 return std::unique_ptr<IWorkload>();
1757 return std::unique_ptr<IWorkload>();
1763 return std::unique_ptr<IWorkload>();
1769 return std::unique_ptr<IWorkload>();
1775 return std::unique_ptr<IWorkload>();
1781 return std::unique_ptr<IWorkload>();
1787 return std::unique_ptr<IWorkload>();
1793 return std::unique_ptr<IWorkload>();
1799 return std::unique_ptr<IWorkload>();
1805 return std::unique_ptr<IWorkload>();
1811 return std::unique_ptr<IWorkload>();
1816 return std::unique_ptr<IWorkload>();
1822 return std::unique_ptr<IWorkload>();
1828 return std::unique_ptr<IWorkload>();
1834 return std::unique_ptr<IWorkload>();
1840 return std::unique_ptr<IWorkload>();
1846 return std::unique_ptr<IWorkload>();
1852 return std::unique_ptr<IWorkload>();
1858 return std::unique_ptr<IWorkload>();
1864 return std::unique_ptr<IWorkload>();
1870 return std::unique_ptr<IWorkload>();
1876 return std::unique_ptr<IWorkload>();
1882 return std::unique_ptr<IWorkload>();
1888 return std::unique_ptr<IWorkload>();
1894 return std::unique_ptr<IWorkload>();
1900 return std::unique_ptr<IWorkload>();
1907 return std::unique_ptr<IWorkload>();
1914 return std::unique_ptr<IWorkload>();
virtual std::unique_ptr< IWorkload > CreateSplitter(const SplitterQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateBatchNormalization(const BatchNormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDebug(const DebugQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMemCopy(const MemCopyQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateL2Normalization(const L2NormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
virtual std::unique_ptr< IWorkload > CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateMultiplication(const MultiplicationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateArgMinMax(const ArgMinMaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateLogSoftmax(const LogSoftmaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
std::vector< BackendOptions > ModelOptions
virtual std::unique_ptr< IWorkload > CreateStridedSlice(const StridedSliceQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateStack(const StackQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLstm(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
constexpr TransformIterator< Function, Iterator > MakeTransformIterator(Iterator i, Function f)
virtual std::unique_ptr< IWorkload > CreateFakeQuantization(const FakeQuantizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateQuantizedLstm(const QuantizedLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateQLstm(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConstant(const ConstantQueueDescriptor &descriptor, const WorkloadInfo &info) const
BackendRegistry & BackendRegistryInstance()
virtual std::unique_ptr< IWorkload > CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
Copyright (c) 2021 ARM Limited and Contributors.
virtual std::unique_ptr< IWorkload > CreateActivation(const ActivationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateTranspose(const TransposeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDivision(const DivisionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMaximum(const MaximumQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConcat(const ConcatQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateUnidirectionalSequenceLstm(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMerge(const MergeQueueDescriptor &descriptor, const WorkloadInfo &info) const
armnn::Optional< armnn::DataType > GetBiasTypeFromWeightsType(armnn::Optional< armnn::DataType > weightsType)
virtual std::unique_ptr< IWorkload > CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateRank(const RankQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateResize(const ResizeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateCast(const CastQueueDescriptor &descriptor, const WorkloadInfo &Info) const
#define ARMNN_ASSERT_MSG(COND, MSG)
virtual std::unique_ptr< IWorkload > CreateQuantize(const QuantizeQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateReduce(const ReduceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSwitch(const SwitchQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreatePad(const PadQueueDescriptor &descriptor, const WorkloadInfo &Info) const
#define ARMNN_ASSERT(COND)
LstmDescriptor UnidirectionalSequenceLstmDescriptor
static bool IsLayerSupported(const BackendId &backendId, const IConnectableLayer &layer, Optional< DataType > dataType, std::string &outReasonIfUnsupported)
virtual std::unique_ptr< IWorkload > CreateNormalization(const NormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLogicalBinary(const LogicalBinaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateReshape(const ReshapeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePermute(const PermuteQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateFill(const FillQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateComparison(const ComparisonQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateConvolution3d(const Convolution3dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePooling2d(const Pooling2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSpaceToDepth(const SpaceToDepthQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateChannelShuffle(const ChannelShuffleQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateGather(const GatherQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMinimum(const MinimumQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDepthToSpace(const DepthToSpaceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSlice(const SliceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateAddition(const AdditionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMean(const MeanQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateOutput(const OutputQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSoftmax(const SoftmaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
Contains information about TensorInfos of a layer.
virtual std::unique_ptr< IWorkload > CreateFullyConnected(const FullyConnectedQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateFloor(const FloorQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMemImport(const MemImportQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSubtraction(const SubtractionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePreCompiled(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateShape(const ShapeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor &descriptor, const WorkloadInfo &info) const
Depthwise Convolution 2D layer workload data.
virtual std::unique_ptr< IWorkload > CreateConvolution2d(const Convolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePrelu(const PreluQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDequantize(const DequantizeQueueDescriptor &descriptor, const WorkloadInfo &info) const