26 template<
typename Float32Func,
typename Uint8Func,
typename ... Params>
27 bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
35 &FalseFunc<Params...>,
38 &FalseFunc<Params...>,
39 &FalseFunc<Params...>,
40 std::forward<Params>(params)...);
48 std::string CreateIncorrectDimensionsErrorMsg(
unsigned int expected,
50 std::string& layerStr,
51 std::string& tensorName)
53 std::string errorMsg =
"Reference " + layerStr +
": Expected " + std::to_string(expected) +
" dimensions but got" +
54 " " + std::to_string(actual) +
" dimensions instead, for the '" + tensorName +
"' tensor.";
62 const std::vector<TensorInfo>& infos,
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
89 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
95 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
101 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
105 std::vector<const TensorInfo*> inputInfos;
106 for (uint32_t i = 0; i < (infos.size() - 1); i++)
108 inputInfos.push_back(&infos[i]);
111 infos[infos.size() - 1],
112 *(PolymorphicDowncast<const OriginsDescriptor*>(&
descriptor)),
127 if (infos.size() != 4)
130 "TensorInfos should be of format: {input, output, weights, biases}.");
133 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&
descriptor));
141 reasonIfUnsupported);
150 reasonIfUnsupported);
156 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
157 reasonIfUnsupported);
160 if (infos.size() != 4)
163 "TensorInfos should be of format: {input, output, weights, biases}.");
166 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&
descriptor));
174 reasonIfUnsupported);
183 reasonIfUnsupported);
193 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
194 reasonIfUnsupported);
198 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
199 reasonIfUnsupported);
207 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
208 reasonIfUnsupported);
213 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
214 reasonIfUnsupported);
219 reasonIfUnsupported);
225 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
227 reasonIfUnsupported);
231 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
232 reasonIfUnsupported);
237 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
238 reasonIfUnsupported);
242 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
243 reasonIfUnsupported);
252 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
253 lstmParamsInfo.
value(),
262 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
263 lstmParamsInfo.
value(),
270 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
271 reasonIfUnsupported);
279 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
280 reasonIfUnsupported);
286 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
287 reasonIfUnsupported);
291 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
292 reasonIfUnsupported);
296 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
297 reasonIfUnsupported);
305 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
306 reasonIfUnsupported);
310 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
311 reasonIfUnsupported);
315 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
316 reasonIfUnsupported);
320 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
321 reasonIfUnsupported);
325 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
326 reasonIfUnsupported);
330 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
331 reasonIfUnsupported);
335 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
336 reasonIfUnsupported);
339 std::vector<TensorInfo> outputInfos;
340 for (uint32_t i = 1; i < infos.size(); i++)
342 outputInfos.push_back(infos[i]);
345 {outputInfos.begin(), outputInfos.end()},
346 *(PolymorphicDowncast<const ViewsDescriptor*>(&
descriptor)),
351 std::vector<const TensorInfo*> inputInfos;
352 for (uint32_t i = 0; i < infos.size() - 1; i++)
354 inputInfos.push_back(&infos[i]);
357 infos[infos.size() - 1],
358 *(PolymorphicDowncast<const StackDescriptor*>(&
descriptor)),
364 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
365 reasonIfUnsupported);
371 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
372 reasonIfUnsupported);
375 if (infos.size() != 4)
378 "TensorInfos should be of format: {input, output, weights, biases}.");
381 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&
descriptor));
389 reasonIfUnsupported);
398 reasonIfUnsupported);
406 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
407 reasonIfUnsupported);
410 if (infos.size() != 4)
413 "TensorInfos should be of format: {input, output, weights, biases}.");
416 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&
descriptor));
424 reasonIfUnsupported);
433 reasonIfUnsupported);
446 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
448 reasonIfUnsupported);
451 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
452 reasonIfUnsupported);
461 if (infos.size() != 6)
464 "should be of format: {input, outputStateIn, cellStateIn, " 465 "hiddenStateOutputVal, cellStateOutputVal, output}");
467 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&
descriptor));
475 lstmParamsInfo.
value(),
481 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
482 reasonIfUnsupported);
497 quantizedLstmInputParamsInfo.
value(),
511 bool supported =
true;
514 std::array<DataType,6> supportedTypes = {
524 "Reference activation: input type not supported.");
527 "Reference activation: output type not supported.");
530 "Reference activation: input and output types mismatched.");
533 "Reference activation: input and output shapes are of different rank.");
536 struct ActivationFunctionSupported :
public Rule 568 supported &=
CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
569 "Reference activation: function not supported.");
579 bool supported =
true;
581 std::array<DataType,7> supportedTypes = {
592 "Reference addition: input 0 is not a supported type.");
595 "Reference addition: input 1 is not a supported type.");
598 "Reference addition: output is not a supported type.");
601 "Reference addition: input 0 and Input 1 types are mismatched");
604 "Reference addition: input and output types are mismatched");
607 "Reference addition: shapes are not suitable for implicit broadcast.");
618 std::array<DataType, 8> supportedInputTypes =
630 std::array<DataType,2> supportedOutputTypes = {
635 bool supported =
true;
638 "Reference ArgMinMax: input is not a supported type.");
640 "Reference ArgMinMax: output type not supported");
656 std::array<DataType, 6> supportedTypes =
666 bool supported =
true;
669 "Reference batch normalization: input is not a supported type.");
672 "Reference batch normalization: output is not a supported type.");
675 "Reference batch normalization: input and output types are mismatched");
678 "Reference batch normalization: mean is not a supported type.");
681 "Reference batch normalization: variance is not a supported type.");
684 "Reference batch normalization: beta is not a supported type.");
687 "Reference batch normalization: gamma is not a supported type.");
699 bool supported =
true;
701 std::string batchToSpaceNdLayerStr =
"batchToSpaceNd";
702 std::string inputTensorStr =
"input";
703 std::string outputTensorStr =
"output";
706 std::array<DataType,6> supportedTypes =
717 "Reference BatchToSpaceNd: input type not supported.");
720 "Reference BatchToSpaceNd: output type not supported.");
723 "Reference BatchToSpaceNd: input and output types mismatched.");
727 CreateIncorrectDimensionsErrorMsg(4,
729 batchToSpaceNdLayerStr,
730 outputTensorStr).data());
734 CreateIncorrectDimensionsErrorMsg(4,
736 batchToSpaceNdLayerStr,
737 inputTensorStr).data());
746 std::array<DataType, 9> supportedInputTypes =
758 bool supported =
true;
760 "Reference cast: input is not a supported type");
764 "Reference cast: output is not a supported type");
767 "Reference cast: input and output shapes have different number of total elements");
778 bool supported =
true;
781 std::array<DataType, 7> supportedTypes =
793 "Reference ChannelShuffle: input is not a supported type.");
796 "Reference ChannelShuffle: output is not a supported type.");
799 "Reference ChannelShuffle: input and output types are mismatched.");
812 std::array<DataType, 8> supportedInputTypes =
824 bool supported =
true;
826 "Reference comparison: input 0 is not a supported type");
829 "Reference comparison: input 0 and Input 1 types are mismatched");
832 "Reference comparison: output is not of type Boolean");
844 bool supported =
true;
845 std::array<DataType,7> supportedTypes =
857 "Reference concatenation: output type not supported");
862 "Reference concatenation: input type not supported");
865 "Reference concatenation: input and output types mismatched.");
874 std::array<DataType,8> supportedTypes =
887 "Reference constant: output is not a supported type.");
894 bool supported =
true;
897 "Reference for ConvertBf16ToFp32 layer: input type not supported");
900 "Reference for ConvertBf16ToFp32 layer: output type not supported");
912 &FalseInputFuncF32<>,
918 &FalseOutputFuncF16<>,
929 bool supported =
true;
932 "Reference for ConvertFp32ToBf16 layer: input type not supported");
935 "Reference for ConvertFp32ToBf16 layer: output type not supported");
946 &FalseInputFuncF16<>,
954 &FalseOutputFuncF32<>,
967 bool supported =
true;
970 std::array<DataType,7> supportedTypes =
982 "Reference Convolution2d: input is not a supported type.");
985 "Reference Convolution2d: output is not a supported type.");
992 reasonIfUnsupported.
value() +=
"Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
999 "Reference Convolution2d: input and output types mismatched.");
1005 std::array<DataType, 3> supportedWeightTypes =
1013 "Reference Convolution2d: weights type not supported for quantized input.");
1018 "Reference Convolution2d: weights is not a supported type.");
1021 "Reference Convolution2d: input and weights types mismatched.");
1026 std::array<DataType,4> biasesSupportedTypes =
1035 "Reference Convolution2d: biases is not a supported type.");
1049 bool supported =
true;
1052 std::array<DataType,7> supportedTypes =
1064 "Reference Convolution3d: input is not a supported type.");
1067 "Reference Convolution3d: output is not a supported type.");
1070 "Reference Convolution3d: input and output types mismatched.");
1075 std::array<DataType, 3> supportedWeightTypes =
1083 "Reference Convolution3d: weights type not supported for quantized input.");
1088 "Reference Convolution3d: weights is not a supported type.");
1091 "Reference Convolution3d: input and weights types mismatched.");
1096 std::array<DataType,4> biasesSupportedTypes =
1105 "Reference Convolution3d: biases is not a supported type.");
1116 bool supported =
true;
1118 std::array<DataType, 8> supportedTypes =
1131 "Reference for Debug layer: input type not supported");
1134 "Reference for Debug layer: output type not supported");
1137 "Reference for Debug layer: input and output types are mismatched");
1148 bool supported =
true;
1150 std::array<DataType,6> supportedTypes =
1161 "Reference DepthToSpace: input type not supported");
1164 "Reference DepthToSpace: output type not supported");
1167 "Reference DepthToSpace: input and output types are mismatched");
1180 bool supported =
true;
1183 std::array<DataType,7> supportedTypes =
1195 "Reference DepthwiseConvolution2d: input is not a supported type.");
1198 "Reference DepthwiseConvolution2d: output is not a supported type.");
1201 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1206 std::array<DataType, 3> supportedWeightTypes =
1214 "Reference DepthwiseConvolution2d: weights type not supported for " 1215 "quantized input.");
1220 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1223 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1228 std::array<DataType,4> biasesSupportedTypes =
1236 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1247 bool supported =
true;
1249 std::array<DataType,5> supportedInputTypes = {
1258 "Reference for Dequantize layer: input type not supported.");
1261 "Reference for Dequantize layer: per-axis quantized input not supported.");
1263 std::array<DataType,3> supportedOutputTypes = {
1270 "Reference for Dequantize layer: output type not supported.");
1273 "Reference for Dequantize layer: input/output shapes have different num total " 1289 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
1291 bool supported =
true;
1293 std::array<DataType,6> supportedInputTypes =
1304 "Reference DetectionPostProcess: input 0 is not a supported type.");
1307 "Reference DetectionPostProcess: input 1 is not a supported type.");
1327 bool supported =
true;
1329 std::array<DataType,7> supportedTypes = {
1340 "Reference division: input 0 is not a supported type.");
1343 "Reference division: input 1 is not a supported type.");
1346 "Reference division: output is not a supported type.");
1349 "Reference division: input 0 and Input 1 types are mismatched");
1352 "Reference division: input and output types are mismatched");
1355 "Reference division: shapes are not suitable for implicit broadcast.");
1367 std::array<DataType, 7> supportedTypes =
1378 std::array<DataType, 1> logicalSupportedTypes =
1383 bool supported =
true;
1388 "Reference elementwise unary: input type not supported");
1391 "Reference elementwise unary: output type not supported");
1396 "Reference elementwise unary: input type not supported");
1399 "Reference elementwise unary: output type not supported");
1403 "Reference elementwise unary: input and output types not matching");
1406 "Reference elementwise unary: input and output shapes" 1407 "have different number of total elements");
1417 bool supported =
true;
1419 std::array<DataType,1> supportedTypes =
1425 "Reference fake quantization: input type not supported.");
1438 bool supported =
true;
1440 std::array<DataType,3> supportedTypes =
1448 "Reference Fill: input type not supported.");
1451 "Reference Fill: output type not supported.");
1460 bool supported =
true;
1462 std::array<DataType,3> supportedTypes =
1470 "Reference Floor: input type not supported.");
1473 "Reference Floor: output type not supported.");
1485 bool supported =
true;
1488 std::array<DataType,6> supportedTypes =
1499 "Reference Fully Connected: input type not supported.");
1502 "Reference Fully Connected: output type not supported.");
1505 "Reference Fully Connected: weights type not supported.");
1512 reasonIfUnsupported.
value() +=
"Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1519 "Reference Fully Connected: input and output types mismatched.");
1523 "Reference Fully Connected: weights is not a supported type.");
1526 "Reference Fully Connected: input and weights types mismatched.");
1531 std::array<DataType, 5>
1532 supportedBiasTypes =
1542 "Reference Fully Connected: bias type not supported.");
1545 "Reference Fully Connected: bias and weight types mismatch.");
1548 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1551 "Reference Fully Connected: bias must have 1 dimension.");
1563 bool supported =
true;
1564 std::array<DataType,7> supportedTypes =
1576 "Reference GatherNd: input type not supported");
1579 "Reference GatherNd: output type not supported");
1582 "Reference GatherNd: indices (input1) type not supported");
1585 "Reference GatherNd: input and output types not matching");
1596 bool supported =
true;
1597 std::array<DataType,7> supportedTypes =
1608 if (descriptor.
m_Axis != 0)
1610 reasonIfUnsupported.
value() += std::string(
"Reference Gather: axis not supported\n");
1614 "Reference Gather: input type not supported");
1617 "Reference Gather: output type not supported");
1620 "Reference Gather: indices (input1) type not supported");
1623 "Reference Gather: input and output types not matching");
1641 std::array<DataType, 3> supportedTypes =
1648 bool supported =
true;
1651 "Reference Instance Normalization: input type not supported.");
1654 "Reference Instance Normalization: output type not supported.");
1657 "Reference Instance Normalization: input and output types mismatched.");
1660 "Reference Instance Normalization: input and output shapes have different " 1661 "num total elements.");
1673 std::array<DataType, 6> supportedTypes =
1683 bool supported =
true;
1686 "Reference L2normalization: input type not supported.");
1689 "Reference L2normalization: output type not supported.");
1692 "Reference L2normalization: input and output types mismatched.");
1695 "Reference L2normalization: input and output shapes have different " 1696 "num total elements.");
1709 std::array<DataType, 1> supportedTypes =
1714 bool supported =
true;
1716 "Reference LogicalBinary: input 0 type not supported");
1718 "Reference LogicalBinary: input 1 type not supported");
1721 "Reference LogicalBinary: input and output types do not match");
1733 std::array<DataType, 3> supportedTypes =
1740 bool supported =
true;
1742 "Reference LogSoftmax: input type not supported");
1745 "Reference LogSoftmax: output type not supported");
1748 "Reference LogSoftmax: input and output types do not match");
1767 bool supported =
true;
1769 std::array<DataType,3> supportedTypes = {
1777 "Reference Lstm: input is not a supported type.");
1779 "Reference Lstm: input and outputStateIn types are mismatched");
1781 "Reference Lstm: input and cellStateIn types are mismatched");
1783 "Reference Lstm: input and scratchBuffer types are mismatched");
1785 "Reference Lstm: input and outputStateOut types are mismatched");
1787 "Reference Lstm: input and cellStateOut types are mismatched");
1790 "Reference Lstm: input and output types are mismatched");
1793 "Reference Lstm: input and InputToForgetWeights types are mismatched");
1795 "Reference Lstm: input and InputToCellWeights types are mismatched");
1797 "Reference Lstm: input and InputToOutputWeights types are mismatched");
1799 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
1801 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
1803 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
1805 "Reference Lstm: input and ForgetGateBias types are mismatched");
1807 "Reference Lstm: input and CellBias types are mismatched");
1809 "Reference Lstm: input and OutputGateBias types are mismatched");
1813 "Reference Lstm: input and InputToInputWeights types are mismatched");
1815 reasonIfUnsupported,
1816 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
1818 "Reference Lstm: input and InputGateBias types are mismatched");
1822 reasonIfUnsupported,
1823 "Reference Lstm: input and CellToInputWeights types are mismatched");
1829 "Reference Lstm: input and CellToForgetWeights types are mismatched");
1831 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1836 "Reference Lstm: input and mProjectionWeights types are mismatched");
1840 "Reference Lstm: input and ProjectionBias types are mismatched");
1848 reasonIfUnsupported,
1849 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1852 reasonIfUnsupported,
1853 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1855 reasonIfUnsupported,
1856 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1858 reasonIfUnsupported,
1859 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1870 bool supported =
true;
1872 std::array<DataType,7> supportedTypes = {
1883 "Reference maximum: input 0 is not a supported type.");
1886 "Reference maximum: input 1 is not a supported type.");
1889 "Reference maximum: output is not a supported type.");
1892 "Reference maximum: input 0 and Input 1 types are mismatched");
1895 "Reference maximum: input and output types are mismatched");
1898 "Reference maximum: shapes are not suitable for implicit broadcast.");
1908 bool supported =
true;
1909 std::string meanLayerStr =
"Mean";
1910 std::string outputTensorStr =
"output";
1912 std::array<DataType,6> supportedTypes =
1923 "Reference Mean: input type not supported.");
1926 "Reference Mean: input and output types are mismatched");
1931 reasonIfUnsupported,
1934 meanLayerStr, outputTensorStr).data());
1936 else if (descriptor.
m_Axis.empty())
1939 reasonIfUnsupported,
1941 meanLayerStr, outputTensorStr).data());
1950 reasonIfUnsupported,
1952 meanLayerStr, outputTensorStr).data());
1957 reasonIfUnsupported,
1959 meanLayerStr, outputTensorStr).data());
1970 bool supported =
true;
1972 std::array<DataType,7> supportedTypes =
1984 "Reference MemCopy: input type not supported");
1987 "Reference MemCopy: output type not supported");
1990 "Reference MemCopy: input and output types are mismatched");
2000 bool supported =
true;
2002 std::array<DataType,7> supportedTypes = {
2013 "Reference minimum: input 0 is not a supported type.");
2016 "Reference minimum: input 1 is not a supported type.");
2019 "Reference minimum: output is not a supported type.");
2022 "Reference minimum: input 0 and Input 1 types are mismatched");
2025 "Reference minimum: input and output types are mismatched");
2028 "Reference minimum: shapes are not suitable for implicit broadcast.");
2038 bool supported =
true;
2040 std::array<DataType,7> supportedTypes = {
2051 "Reference multiplication: input 0 is not a supported type.");
2054 "Reference multiplication: input 1 is not a supported type.");
2057 "Reference multiplication: output is not a supported type.");
2060 "Reference multiplication: input 0 and Input 1 types are mismatched");
2063 "Reference multiplication: input and output types are mismatched");
2066 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2079 std::array<DataType, 6> supportedTypes =
2089 bool supported =
true;
2092 "Reference normalization: input type not supported.");
2095 "Reference normalization: output type not supported.");
2098 "Reference normalization: input and output shapes have different " 2099 "num total elements.");
2116 bool supported =
true;
2119 std::array<DataType,6> supportedTypes =
2130 "Reference pad: input is not a supported type.");
2133 "Reference pad: output is not a supported type.");
2136 "Reference pad: input and output types are mismatched.");
2147 bool supported =
true;
2150 std::array<DataType, 6> supportedTypes =
2161 "Reference permute: input is not a supported type.");
2164 "Reference permute: output is not a supported type.");
2167 "Reference permute: input and output types are mismatched.");
2178 bool supported =
true;
2181 std::array<DataType,6> supportedTypes =
2192 "Reference poolind2d: input is not a supported type.");
2195 "Reference poolind2d: output is not a supported type.");
2198 "Reference poolind2d: input and output types are mismatched.");
2209 bool supported =
true;
2212 std::array<DataType,6> supportedTypes =
2223 "Reference poolind3d: input is not a supported type.");
2226 "Reference poolind3d: output is not a supported type.");
2229 "Reference poolind3d: input and output types are mismatched.");
2263 bool supported =
true;
2266 std::array<DataType,7> supportedInputTypes = {
2277 "Reference quantize: input type not supported.");
2280 std::array<DataType,4> supportedOutputTypes = {
2287 "Reference quantize: output type not supported.");
2290 "Reference quantize: input and output shapes have different num total elements.");
2301 std::array<DataType,1> supportedOutputTypes =
2307 "Reference rank: input type not supported.");
2316 bool supported =
true;
2317 std::array<DataType,7> supportedTypes =
2329 "Reference Reduce: input type not supported");
2332 "Reference Reduce: output type not supported");
2335 "Reference Reduce: input and output types not matching");
2348 std::array<DataType,8> supportedOutputTypes =
2361 "Reference reshape: input type not supported.");
2370 bool supported =
true;
2371 std::array<DataType,6> supportedTypes =
2382 "Reference Resize: input type not supported");
2385 "Reference Resize: output type not supported");
2388 "Reference Resize: input and output types not matching");
2398 bool supported =
true;
2400 std::array<DataType, 1> supportedTypes =
2406 "Reference Shape: output type not supported");
2417 bool supported =
true;
2419 std::array<DataType, 5> supportedTypes =
2429 "Reference Slice: input type not supported");
2432 "Reference Slice: output type not supported");
2435 "Reference Slice: input and output types are mismatched");
2446 bool supported =
true;
2447 std::array<DataType,7> supportedTypes =
2459 "Reference Softmax: output type not supported");
2462 "Reference Softmax: input type not supported");
2465 "Reference Softmax: input type not supported");
2476 bool supported =
true;
2477 std::array<DataType,6> supportedTypes =
2488 "Reference SpaceToBatchNd: input type not supported");
2491 "Reference SpaceToBatchNd: output type not supported");
2494 "Reference SpaceToBatchNd: input and output types are mismatched");
2506 bool supported =
true;
2508 std::array<DataType,6> supportedTypes =
2519 "Reference SpaceToDepth: input type not supported");
2522 "Reference SpaceToDepth: output type not supported");
2525 "Reference SpaceToDepth: input and output types are mismatched");
2531 const std::vector<std::reference_wrapper<TensorInfo>>&
outputs,
2536 bool supported =
true;
2537 std::array<DataType,6> supportedTypes =
2548 "Reference splitter: output type not supported");
2552 "Reference splitter: input type not supported");
2555 "Reference splitter: input and output types mismatched.");
2568 bool supported =
true;
2569 std::array<DataType,7> supportedTypes =
2581 "Reference stack: output type not supported");
2586 "Reference stack: input type not supported");
2589 "Reference stack: input and output types mismatched.");
2601 bool supported =
true;
2603 std::array<DataType,5> supportedTypes =
2613 "Reference StridedSlice: input type not supported");
2616 "Reference StridedSlice: output type not supported");
2619 "Reference StridedSlice: input and output types are mismatched");
2629 bool supported =
true;
2631 std::array<DataType,7> supportedTypes = {
2642 "Reference subtraction: input 0 is not a supported type.");
2645 "Reference subtraction: input 1 is not a supported type.");
2648 "Reference subtraction: output is not a supported type.");
2651 "Reference subtraction: input 0 and Input 1 types are mismatched");
2654 "Reference subtraction: input and output types are mismatched");
2657 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2667 bool supported =
true;
2669 std::array<DataType, 6> supportedTypes
2680 "PReLU: input is not a supported type.");
2683 "PReLU: alpha is not a supported type.");
2686 "PReLU: output is not a supported type.");
2689 "PReLU: input, alpha and output types are mismatched");
2692 "PReLU: shapes are not suitable for implicit broadcast");
2705 bool supported =
true;
2707 std::array<DataType,7> supportedTypes =
2719 "Reference TransposeConvolution2d: input is not a supported type.");
2722 "Reference TransposeConvolution2d: output is not a supported type.");
2725 "Reference TransposeConvolution2d: input and output types mismatched.");
2731 std::array<DataType, 3> supportedWeightTypes =
2739 "Reference TransposeConvolution2d: weights type not supported for " 2740 "quantized input.");
2745 "Reference TransposeConvolution2d: weights is not a supported type.");
2748 "Reference TransposeConvolution2d: input and weights types mismatched.");
2753 std::array<DataType,4> biasesSupportedTypes =
2761 "Reference TransposeConvolution2d: biases is not a supported type.");
2773 bool supported =
true;
2776 std::array<DataType, 6> supportedTypes =
2787 "Reference transpose: input is not a supported type.");
2790 "Reference transpose: output is not a supported type.");
2793 "Reference transpose: input and output types are mismatched.");
2815 bool supported =
true;
2817 std::array<DataType, 2> supportedTypes =
2823 std::array<DataType, 2> supportedWeightTypes =
2829 std::array<DataType, 3> supportedBiasTypes =
2838 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2840 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
2844 reasonIfUnsupported,
2845 "Reference UnidirectionalSequenceLstm: InputToForgetWeights " 2846 "is not a supported type.");
2848 reasonIfUnsupported,
2849 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2851 reasonIfUnsupported,
2852 "Reference UnidirectionalSequenceLstm: InputToOutputWeights " 2853 "is not a supported type.");
2855 reasonIfUnsupported,
2856 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights " 2857 "is not a supported type.");
2859 reasonIfUnsupported,
2860 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights " 2861 "is not a supported type.");
2863 reasonIfUnsupported,
2864 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights " 2865 "is not a supported type.");
2868 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2870 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2872 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
2876 reasonIfUnsupported,
2877 "Reference UnidirectionalSequenceLstm: InputToInputWeights " 2878 "is not a supported type.");
2880 reasonIfUnsupported,
2881 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights " 2882 "is not a supported type.");
2884 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
2888 reasonIfUnsupported,
2889 "Reference UnidirectionalSequenceLstm: CellToInputWeights " 2890 "is not a supported type.");
2896 reasonIfUnsupported,
2897 "Reference UnidirectionalSequenceLstm: CellToForgetWeights " 2898 "is not a supported type.");
2900 reasonIfUnsupported,
2901 "Reference UnidirectionalSequenceLstm: CellToOutputWeights " 2902 "is not a supported type.");
2907 reasonIfUnsupported,
2908 "Reference UnidirectionalSequenceLstm: ProjectionWeights " 2909 "is not a supported type.");
2913 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types " 2922 reasonIfUnsupported,
2923 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights " 2924 "is not a supported type.");
2927 reasonIfUnsupported,
2928 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights " 2929 "is not a supported type.");
2931 reasonIfUnsupported,
2932 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights " 2933 "is not a supported type.");
2935 reasonIfUnsupported,
2936 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights " 2937 "is not a supported type.");
bool m_ProjectionEnabled
Enable/disable the projection layer.
bool IsComparisonSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, const ComparisonDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
UnaryOperation m_Operation
Specifies the elementwiseUnary operation to execute.
bool IsReshapeSupported(const TensorInfo &input, const TensorInfo &output, const ReshapeDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A ViewsDescriptor for the SplitterLayer.
const TensorInfo const TensorInfo & anchors
bool IsSoftmaxSupported(const TensorInfo &input, const TensorInfo &output, const SoftmaxDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A TransposeConvolution2dDescriptor for the TransposeConvolution2dLayer.
const TensorInfo const TensorInfo const TensorInfo const TensorInfo const TensorInfo const TensorInfo const LstmDescriptor const LstmInputParamsInfo & paramsInfo
bool IsPermuteSupported(const TensorInfo &input, const TensorInfo &output, const PermuteDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsPadSupported(const TensorInfo &input, const TensorInfo &output, const PadDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo & output
bool IsLogSoftmaxSupported(const TensorInfo &input, const TensorInfo &output, const LogSoftmaxDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported) const override
A ReshapeDescriptor for the ReshapeLayer.
bool IsMemImportSupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsGatherSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, const GatherDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsUnidirectionalSequenceLstmSupported(const TensorInfo &input, const TensorInfo &outputStateIn, const TensorInfo &cellStateIn, const TensorInfo &outputStateOut, const TensorInfo &cellStateOut, const TensorInfo &output, const UnidirectionalSequenceLstmDescriptor &descriptor, const LstmInputParamsInfo ¶msInfo, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsConvertFp32ToFp16Supported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A ComparisonDescriptor for the ComparisonLayer.
bool IsDilatedDepthwiseConvolutionSupported(const TensorInfo &input, const TensorInfo &output, const DepthwiseConvolution2dDescriptor &descriptor, const TensorInfo &weights, const Optional< TensorInfo > &biases, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo const TensorInfo const TensorInfo const TensorInfo const TensorInfo & gamma
bool IsQuantizedLstmSupported(const TensorInfo &input, const TensorInfo &previousCellStateIn, const TensorInfo &previousOutputIn, const TensorInfo &cellStateOut, const TensorInfo &output, const QuantizedLstmInputParamsInfo ¶msInfo, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const std::vector< std::reference_wrapper< TensorInfo > > & outputs
bool IsSliceSupported(const TensorInfo &input, const TensorInfo &output, const SliceDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsStackSupported(const std::vector< const TensorInfo *> &inputs, const TensorInfo &output, const StackDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A Convolution2dDescriptor for the Convolution2dLayer.
bool IsPreluSupported(const TensorInfo &input, const TensorInfo &alpha, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsConvertBf16ToFp32Supported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsSplitterSupported(const TensorInfo &input, const std::vector< std::reference_wrapper< TensorInfo >> &outputs, const ViewsDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo const ActivationDescriptor Optional< std::string & > reasonIfUnsupported
bool IsLogicalBinarySupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, const LogicalBinaryDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported) const override
bool IsQLstmSupported(const TensorInfo &input, const TensorInfo &previousOutputIn, const TensorInfo &previousCellStateIn, const TensorInfo &outputStateOut, const TensorInfo &cellStateOut, const TensorInfo &output, const QLstmDescriptor &descriptor, const LstmInputParamsInfo ¶msInfo, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsL2NormalizationSupported(const TensorInfo &input, const TensorInfo &output, const L2NormalizationDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsConvolution3dSupported(const TensorInfo &input, const TensorInfo &output, const Convolution3dDescriptor &descriptor, const TensorInfo &weights, const Optional< TensorInfo > &biases, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A LogicalBinaryDescriptor for the LogicalBinaryLayer.
const TensorInfo & scores
bool IsDepthToSpaceSupported(const TensorInfo &input, const TensorInfo &output, const DepthToSpaceDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo const TensorInfo const TensorInfo const TensorInfo & detectionClasses
bool IsGatherNdSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const
bool IsBatchNormalizationSupported(const TensorInfo &input, const TensorInfo &output, const TensorInfo &mean, const TensorInfo &var, const TensorInfo &beta, const TensorInfo &gamma, const BatchNormalizationDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
const TensorInfo const ActivationDescriptor & descriptor
A SpaceToDepthDescriptor for the SpaceToDepthLayer.
bool IsDepthwiseConvolutionSupported(const TensorInfo &input, const TensorInfo &output, const DepthwiseConvolution2dDescriptor &descriptor, const TensorInfo &weights, const Optional< TensorInfo > &biases, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsFakeQuantizationSupported(const TensorInfo &input, const FakeQuantizationDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsBatchToSpaceNdSupported(const TensorInfo &input, const TensorInfo &output, const BatchToSpaceNdDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsConcatSupported(const std::vector< const TensorInfo *> inputs, const TensorInfo &output, const OriginsDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A BatchToSpaceNdDescriptor for the BatchToSpaceNdLayer.
const TensorInfo & outputStateIn
const TensorInfo const TensorInfo & previousCellStateIn
const TensorInfo const TensorInfo const TensorInfo const TensorInfo const TensorInfo const TensorInfo & numDetections
bool IsTransposeConvolution2dSupported(const TensorInfo &input, const TensorInfo &output, const TransposeConvolution2dDescriptor &descriptor, const TensorInfo &weights, const Optional< TensorInfo > &biases, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A ResizeBilinearDescriptor for the ResizeBilinearLayer.
Base class for all descriptors.
std::vector< unsigned int > m_Axis
Values for the dimensions to reduce.
A StackDescriptor for the StackLayer.
constexpr bool IsQuantized8BitType(DataType dataType)
bool IsOutputSupported(const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsFullyConnectedSupported(const TensorInfo &input, const TensorInfo &output, const TensorInfo &weights, const TensorInfo &biases, const FullyConnectedDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsQuantizeSupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsResizeSupported(const TensorInfo &input, const TensorInfo &output, const ResizeDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A PadDescriptor for the PadLayer.
bool IsConstantSupported(const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsArgMinMaxSupported(const TensorInfo &input, const TensorInfo &output, const ArgMinMaxDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo const TensorInfo & cellStateIn
bool IsSpaceToBatchNdSupported(const TensorInfo &input, const TensorInfo &output, const SpaceToBatchNdDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
An LstmDescriptor for the LstmLayer.
bool IsPooling3dSupported(const TensorInfo &input, const TensorInfo &output, const Pooling3dDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool m_KeepDims
Enable/disable keep dimensions. If true, then the reduced dimensions that are of length 1 are kept...
bool IsMinimumSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A L2NormalizationDescriptor for the L2NormalizationLayer.
An ArgMinMaxDescriptor for ArgMinMaxLayer.
DataType GetDataType() const
An OriginsDescriptor for the ConcatLayer.
A ReduceDescriptor for the REDUCE operators.
bool has_value() const noexcept
A FullyConnectedDescriptor for the FullyConnectedLayer.
bool IsRankSupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool m_BiasEnabled
Enable/disable bias.
const TensorInfo const TensorInfo const TensorInfo const TensorInfo & outputStateOut
A FakeQuantizationDescriptor for the FakeQuantizationLayer.
const TensorInfo const TensorInfo const TensorInfo const TensorInfo const TensorInfo & cellStateOut
bool IsConvertFp16ToFp32Supported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A GatherDescriptor for the GatherLayer.
bool m_PeepholeEnabled
Enable/disable peephole.
const TensorInfo const TensorInfo const TensorInfo const TensorInfo & beta
#define ARMNN_ASSERT(COND)
bool IsMeanSupported(const TensorInfo &input, const TensorInfo &output, const MeanDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A QLstmDescriptor for the QLstmLayer.
bool IsSpaceToDepthSupported(const TensorInfo &input, const TensorInfo &output, const SpaceToDepthDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsMergeSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsAdditionSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsStridedSliceSupported(const TensorInfo &input, const TensorInfo &output, const StridedSliceDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsSubtractionSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
An ActivationDescriptor for the ActivationLayer.
bool IsConvertFp32ToBf16Supported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsFloorSupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsActivationSupported(const TensorInfo &input, const TensorInfo &output, const ActivationDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
min(a, max(b, input)) ReLu1 & ReLu6.
bool IsDebugSupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A SliceDescriptor for the SliceLayer.
A Convolution3dDescriptor for the Convolution3dLayer.
const TensorInfo & previousOutputIn
A Pooling3dDescriptor for the Pooling3dLayer.
bool IsDetectionPostProcessSupported(const TensorInfo &boxEncodings, const TensorInfo &scores, const TensorInfo &anchors, const TensorInfo &detectionBoxes, const TensorInfo &detectionClasses, const TensorInfo &detectionScores, const TensorInfo &numDetections, const DetectionPostProcessDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A SpaceToBatchNdDescriptor for the SpaceToBatchNdLayer.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
bool IsLayerSupported(const LayerType &type, const std::vector< TensorInfo > &infos, const BaseDescriptor &descriptor, const Optional< LstmInputParamsInfo > &lstmParamsInfo, const Optional< QuantizedLstmInputParamsInfo > &, Optional< std::string &> reasonIfUnsupported) const override
const TensorInfo const TensorInfo const TensorInfo const TensorInfo const TensorInfo & detectionScores
EmptyOptional is used to initialize the Optional class in case we want to have default value for an O...
int32_t m_Axis
The axis in params to gather indices from.
A ElementwiseUnaryDescriptor for the ElementwiseUnaryLayer.
bool IsElementwiseUnarySupported(const TensorInfo &input, const TensorInfo &output, const ElementwiseUnaryDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsMultiplicationSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsNormalizationSupported(const TensorInfo &input, const TensorInfo &output, const NormalizationDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsDequantizeSupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo const Convolution2dDescriptor const TensorInfo const Optional< TensorInfo > & biases
A MeanDescriptor for the MeanLayer.
const TensorInfo const TensorInfo const TensorInfo & detectionBoxes
bool IsMemCopySupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool m_LayerNormEnabled
Enable/disable layer normalization.
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
A TransposeDescriptor for the TransposeLayer.
A StridedSliceDescriptor for the StridedSliceLayer.
bool IsInputSupported(const TensorInfo &input, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsChannelShuffleSupported(const TensorInfo &input, const TensorInfo &output, const ChannelShuffleDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsDivisionSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo & input1
bool IsLstmSupported(const TensorInfo &input, const TensorInfo &outputStateIn, const TensorInfo &cellStateIn, const TensorInfo &scratchBuffer, const TensorInfo &outputStateOut, const TensorInfo &cellStateOut, const TensorInfo &output, const LstmDescriptor &descriptor, const LstmInputParamsInfo ¶msInfo, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsFillSupported(const TensorInfo &input, const TensorInfo &output, const FillDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsShapeSupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A Pooling2dDescriptor for the Pooling2dLayer.
A NormalizationDescriptor for the NormalizationLayer.
bool IsTransposeSupported(const TensorInfo &input, const TensorInfo &output, const TransposeDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsCastSupported(const TensorInfo &input, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo const TensorInfo const TensorInfo & scratchBuffer
An InstanceNormalizationDescriptor for InstanceNormalizationLayer.
A ChannelShuffleDescriptor for the ChannelShuffle operator.
bool IsConvolution2dSupported(const TensorInfo &input, const TensorInfo &output, const Convolution2dDescriptor &descriptor, const TensorInfo &weights, const Optional< TensorInfo > &biases, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
bool IsInstanceNormalizationSupported(const TensorInfo &input, const TensorInfo &output, const InstanceNormalizationDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
unsigned int GetNumDimensions() const
bool IsSupportedForDataTypeGeneric(Optional< std::string &> reasonIfUnsupported, DataType dataType, Float16Func float16FuncPtr, Float32Func float32FuncPtr, Uint8Func uint8FuncPtr, Int32Func int32FuncPtr, BooleanFunc booleanFuncPtr, Params &&... params)
A SoftmaxDescriptor for the SoftmaxLayer.
bool CheckSupportRule(F rule, Optional< std::string &> reasonIfUnsupported, const char *reason)
const TensorInfo const Convolution2dDescriptor const TensorInfo & weights
ActivationFunction m_Function
The activation function to use (Sigmoid, TanH, Linear, ReLu, BoundedReLu, SoftReLu, LeakyReLu, Abs, Sqrt, Square, Elu).
bool IsMaximumSupported(const TensorInfo &input0, const TensorInfo &input1, const TensorInfo &output, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A DepthwiseConvolution2dDescriptor for the DepthwiseConvolution2dLayer.
A FillDescriptor for the FillLayer.
bool IsPooling2dSupported(const TensorInfo &input, const TensorInfo &output, const Pooling2dDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
A BatchNormalizationDescriptor for the BatchNormalizationLayer.
bool IsReduceSupported(const TensorInfo &input, const TensorInfo &output, const ReduceDescriptor &descriptor, Optional< std::string &> reasonIfUnsupported=EmptyOptional()) const override
const TensorInfo const TensorInfo & mean
A PermuteDescriptor for the PermuteLayer.
LayerType
When adding a new layer, adapt also the LastLayer enum value in the enum class LayerType below...