ArmNN 21.11
RefLayerSupport.cpp
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefLayerSupport.hpp"
7 
8 #include <armnn/TypesUtils.hpp>
9 #include <armnn/Types.hpp>
10 #include <armnn/Descriptors.hpp>
13 
14 #include <LayerSupportCommon.hpp>
16 
17 #include <vector>
18 #include <array>
19 
20 namespace armnn
21 {
22 
23 namespace
24 {
25 
26 template<typename Float32Func, typename Uint8Func, typename ... Params>
27 bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28  DataType dataType,
29  Float32Func floatFuncPtr,
30  Uint8Func uint8FuncPtr,
31  Params&&... params)
32 {
33  return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34  dataType,
35  &FalseFunc<Params...>,
36  floatFuncPtr,
37  uint8FuncPtr,
38  &FalseFunc<Params...>,
39  &FalseFunc<Params...>,
40  std::forward<Params>(params)...);
41 }
42 
43 } // anonymous namespace
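The checks in this file are built from the rule helpers declared in LayerSupportRules.hpp (Rule, CheckSupportRule, TypeAnyOf, TypesAreEqual, and so on). As a rough sketch of that pattern — simplified, with names and signatures guessed from how they are used below (the real definitions may differ) — the mechanism looks roughly like this:

// Simplified, illustrative sketch only; not part of RefLayerSupport.cpp.
// The real Rule/CheckSupportRule/TypeAnyOf live in backendsCommon/LayerSupportRules.hpp.
#include <algorithm>
#include <string>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

struct Rule
{
    bool operator()() const { return m_Res; }
    bool m_Res = true;   // concrete rules set this in their constructors (see ActivationFunctionSupported below)
};

// Evaluate a rule; on failure append the message to the caller's reason string.
template <typename RuleT>
bool CheckSupportRule(const RuleT& rule, armnn::Optional<std::string&> reasonIfUnsupported, const char* reason)
{
    const bool supported = rule();
    if (!supported && reasonIfUnsupported.has_value())
    {
        reasonIfUnsupported.value() += std::string(reason) + "\n";
    }
    return supported;
}

// Example concrete rule: is the tensor's data type one of the listed types?
struct TypeAnyOf : public Rule
{
    template <typename Container>
    TypeAnyOf(const armnn::TensorInfo& info, const Container& container)
    {
        m_Res = std::any_of(container.begin(), container.end(),
                            [&](armnn::DataType dt) { return dt == info.GetDataType(); });
    }
};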
44 
45 namespace
46 {
47 
48 std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
49  unsigned int actual,
50  std::string& layerStr,
51  std::string& tensorName)
52 {
53  std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
54  " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
55 
56  return errorMsg;
57 }
58 
59 } // anonymous namespace
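For illustration (this snippet is not part of the listed file), the helper above produces messages of the following shape:

// Hypothetical call, mirroring how the helper is used for the BatchToSpaceNd checks below.
std::string layerStr  = "batchToSpaceNd";
std::string tensorStr = "output";
std::string msg = CreateIncorrectDimensionsErrorMsg(4, 2, layerStr, tensorStr);
// msg == "Reference batchToSpaceNd: Expected 4 dimensions but got 2 dimensions instead, for the 'output' tensor."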
60 
61 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
62  const TensorInfo& output,
63  const ActivationDescriptor& descriptor,
64  Optional<std::string&> reasonIfUnsupported) const
65 {
66  bool supported = true;
67 
68  // Define supported types.
69  std::array<DataType,6> supportedTypes = {
76  };
77 
78  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
79  "Reference activation: input type not supported.");
80 
81  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
82  "Reference activation: output type not supported.");
83 
84  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
85  "Reference activation: input and output types mismatched.");
86 
87  supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
88  "Reference activation: input and output shapes are of different rank.");
89 
90 
91  struct ActivationFunctionSupported : public Rule
92  {
93  ActivationFunctionSupported(const ActivationDescriptor& desc)
94  {
95  switch(desc.m_Function)
96  {
109  {
110  m_Res = true;
111  break;
112  }
113  default:
114  {
115  m_Res = false;
116  break;
117  }
118  }
119  }
120  };
121 
122  // Function is supported
123  supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
124  "Reference activation: function not supported.");
125 
126  return supported;
127 }
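A minimal caller-side sketch of the check above, assuming RefLayerSupport is constructed directly and that the reference backend headers are on the include path (in practice support queries normally go through the backend's ILayerSupport interface):

// Illustrative only; not part of RefLayerSupport.cpp.
#include "RefLayerSupport.hpp"
#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <string>

bool IsReluSupportedOnRef()
{
    using namespace armnn;

    const TensorInfo inputInfo ({1, 16}, DataType::Float32);
    const TensorInfo outputInfo({1, 16}, DataType::Float32);

    ActivationDescriptor descriptor;
    descriptor.m_Function = ActivationFunction::ReLu;

    std::string reason;
    RefLayerSupport layerSupport;
    const bool supported = layerSupport.IsActivationSupported(inputInfo, outputInfo, descriptor,
                                                              Optional<std::string&>(reason));
    // On failure, 'reason' accumulates the messages produced by the rules above.
    return supported;
}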
128 
129 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
130  const TensorInfo& input1,
131  const TensorInfo& output,
132  Optional<std::string&> reasonIfUnsupported) const
133 {
134  bool supported = true;
135 
136  std::array<DataType,7> supportedTypes = {
144  };
145 
146  supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
147  "Reference addition: input 0 is not a supported type.");
148 
149  supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
150  "Reference addition: input 1 is not a supported type.");
151 
152  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
153  "Reference addition: output is not a supported type.");
154 
155  supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
156  "Reference addition: input 0 and input 1 types are mismatched");
157 
158  supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
159  "Reference addition: input and output types are mismatched");
160 
161  supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
162  "Reference addition: shapes are not suitable for implicit broadcast.");
163 
164  return supported;
165 }
166 
167 bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
168  const armnn::ArgMinMaxDescriptor &descriptor,
169  armnn::Optional<std::string &> reasonIfUnsupported) const
170 {
171  IgnoreUnused(descriptor);
172 
173  std::array<DataType, 8> supportedInputTypes =
174  {
183  };
184 
185  std::array<DataType,2> supportedOutputTypes = {
187  DataType::Signed64
188  };
189 
190  bool supported = true;
191 
192  supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
193  "Reference ArgMinMax: input is not a supported type.");
194  supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
195  "Reference ArgMinMax: output type not supported");
196 
197  return supported;
198 }
199 
200 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
201  const TensorInfo& output,
202  const TensorInfo& mean,
203  const TensorInfo& variance,
204  const TensorInfo& beta,
205  const TensorInfo& gamma,
206  const BatchNormalizationDescriptor& descriptor,
207  Optional<std::string&> reasonIfUnsupported) const
208 {
209  IgnoreUnused(descriptor);
210 
211  std::array<DataType, 6> supportedTypes =
212  {
219  };
220 
221  bool supported = true;
222 
223  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
224  "Reference batch normalization: input is not a supported type.");
225 
226  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
227  "Reference batch normalization: output is not a supported type.");
228 
229  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
230  "Reference batch normalization: input and output types are mismatched");
231 
232  supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
233  "Reference batch normalization: mean is not a supported type.");
234 
235  supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
236  "Reference batch normalization: variance is not a supported type.");
237 
238  supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
239  "Reference batch normalization: beta is not a supported type.");
240 
241  supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
242  "Reference batch normalization: gamma is not a supported type.");
243 
244  return supported;
245 }
246 
247 bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
248  const TensorInfo& output,
249  const BatchToSpaceNdDescriptor& descriptor,
250  Optional<std::string&> reasonIfUnsupported) const
251 {
252  IgnoreUnused(descriptor);
253 
254  bool supported = true;
255 
256  std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
257  std::string inputTensorStr = "input";
258  std::string outputTensorStr = "output";
259 
260  // Define supported types.
261  std::array<DataType,6> supportedTypes =
262  {
269  };
270 
271  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
272  "Reference BatchToSpaceNd: input type not supported.");
273 
274  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
275  "Reference BatchToSpaceNd: output type not supported.");
276 
277  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
278  "Reference BatchToSpaceNd: input and output types mismatched.");
279 
280  supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
281  reasonIfUnsupported,
282  CreateIncorrectDimensionsErrorMsg(4,
283  output.GetNumDimensions(),
284  batchToSpaceNdLayerStr,
285  outputTensorStr).data());
286 
287  supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
288  reasonIfUnsupported,
289  CreateIncorrectDimensionsErrorMsg(4,
290  input.GetNumDimensions(),
291  batchToSpaceNdLayerStr,
292  inputTensorStr).data());
293 
294  return supported;
295 }
296 
297 bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
298  const TensorInfo& output,
299  Optional<std::string&> reasonIfUnsupported) const
300 {
301  std::array<DataType, 9> supportedInputTypes =
302  {
311  };
312 
313  bool supported = true;
314  supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
315  "Reference cast: input is not a supported type");
316 
317 
318  supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
319  "Reference cast: output is not a supported type");
320 
321  supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
322  "Reference cast: input and output shapes have different number of total elements");
323 
324  return supported;
325 }
326 
327 bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
328  const TensorInfo& output,
329  const ChannelShuffleDescriptor& descriptor,
330  Optional<std::string&> reasonIfUnsupported) const
331 {
332  IgnoreUnused(descriptor);
333  bool supported = true;
334 
335  // Define supported output and inputs types.
336  std::array<DataType, 7> supportedTypes =
337  {
345  };
346 
347  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
348  "Reference ChannelShuffle: input is not a supported type.");
349 
350  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
351  "Reference ChannelShuffle: output is not a supported type.");
352 
353  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
354  "Reference ChannelShuffle: input and output types are mismatched.");
355 
356  return supported;
357 }
358 
359 
360 bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
361  const TensorInfo& input1,
362  const TensorInfo& output,
363  const ComparisonDescriptor& descriptor,
364  Optional<std::string&> reasonIfUnsupported) const
365 {
366  IgnoreUnused(descriptor);
367  std::array<DataType, 8> supportedInputTypes =
368  {
377  };
378 
379  bool supported = true;
380  supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
381  "Reference comparison: input 0 is not a supported type");
382 
383  supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
384  "Reference comparison: input 0 and input 1 types are mismatched");
385 
386  supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
387  "Reference comparison: output is not of type Boolean");
388 
389  return supported;
390 }
391 
392 bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
393  const TensorInfo& output,
394  const ConcatDescriptor& descriptor,
395  Optional<std::string&> reasonIfUnsupported) const
396 {
397  IgnoreUnused(descriptor);
398 
399  bool supported = true;
400  std::array<DataType,6> supportedTypes =
401  {
408  };
409 
410  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
411  "Reference concatenation: output type not supported");
412  for (const TensorInfo* input : inputs)
413  {
414  ARMNN_ASSERT(input != nullptr);
415  supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
416  "Reference concatenation: input type not supported");
417 
418  supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
419  "Reference concatenation: input and output types mismatched.");
420  }
421 
422  return supported;
423 }
424 
425 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
426  Optional<std::string&> reasonIfUnsupported) const
427 {
428  std::array<DataType,8> supportedTypes =
429  {
438  };
439 
440  return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
441  "Reference constant: output is not a supported type.");
442 }
443 
444 bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
445  const TensorInfo& output,
446  Optional<std::string&> reasonIfUnsupported) const
447 {
448  bool supported = true;
449 
450  supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
451  "Reference for ConvertBf16ToFp32 layer: input type not supported");
452 
453  supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
454  "Reference for ConvertBf16ToFp32 layer: output type not supported");
455 
456  return supported;
457 }
458 
459 bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
460  const TensorInfo& output,
461  Optional<std::string&> reasonIfUnsupported) const
462 {
463  return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
464  input.GetDataType(),
465  &TrueFunc<>,
466  &FalseInputFuncF32<>,
467  &FalseFuncU8<>,
468  &FalseFuncI32<>,
469  &FalseFuncU8<>) &&
470  IsSupportedForDataTypeGeneric(reasonIfUnsupported,
471  output.GetDataType(),
472  &FalseOutputFuncF16<>,
473  &TrueFunc<>,
474  &FalseFuncU8<>,
475  &FalseFuncI32<>,
476  &FalseFuncU8<>));
477 }
478 
479 bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
480  const TensorInfo& output,
481  Optional<std::string&> reasonIfUnsupported) const
482 {
483  bool supported = true;
484 
485  supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
486  "Reference for ConvertFp32ToBf16 layer: input type not supported");
487 
488  supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
489  "Reference for ConvertFp32ToBf16 layer: output type not supported");
490 
491  return supported;
492 }
493 
494 bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
495  const TensorInfo& output,
496  Optional<std::string&> reasonIfUnsupported) const
497 {
498  return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
499  input.GetDataType(),
500  &FalseInputFuncF16<>,
501  &TrueFunc<>,
502  &FalseFuncU8<>,
503  &FalseFuncI32<>,
504  &FalseFuncU8<>) &&
505  IsSupportedForDataTypeGeneric(reasonIfUnsupported,
506  output.GetDataType(),
507  &TrueFunc<>,
508  &FalseOutputFuncF32<>,
509  &FalseFuncU8<>,
510  &FalseFuncI32<>,
511  &FalseFuncU8<>));
512 }
513 
514 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
515  const TensorInfo& output,
516  const Convolution2dDescriptor& descriptor,
517  const TensorInfo& weights,
518  const Optional<TensorInfo>& biases,
519  Optional<std::string&> reasonIfUnsupported) const
520 {
521  bool supported = true;
522 
523  // Define supported types.
524  std::array<DataType,7> supportedTypes =
525  {
533  };
534 
535  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
536  "Reference Convolution2d: input is not a supported type.");
537 
538  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
539  "Reference Convolution2d: output is not a supported type.");
540 
541  // For Convolution2d, BFloat16 input with Float32 output is allowed as an optimization.
542  if (input.GetDataType() == DataType::BFloat16)
543  {
544  if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
545  {
546  reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
547  supported = false;
548  }
549  }
550  else
551  {
552  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
553  "Reference Convolution2d: input and output types mismatched.");
554  }
555 
556  const DataType inputType = input.GetDataType();
557  if (IsQuantized8BitType(inputType))
558  {
559  std::array<DataType, 3> supportedWeightTypes =
560  {
564  };
565 
566  supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
567  "Reference Convolution2d: weights type not supported for quantized input.");
568  }
569  else
570  {
571  supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
572  "Reference Convolution2d: weights is not a supported type.");
573 
574  supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
575  "Reference Convolution2d: input and weights types mismatched.");
576  }
577 
578  if (biases.has_value())
579  {
580  std::array<DataType,4> biasesSupportedTypes =
581  {
586  };
587 
588  supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
589  "Reference Convolution2d: biases is not a supported type.");
590  }
591  IgnoreUnused(descriptor);
592 
593  return supported;
594 }
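As a small illustration of the BFloat16 special case above (not part of the listed file): with BFloat16 input, only BFloat16 or Float32 output passes; any other output type appends to the reason string and fails.

// Illustrative tensor setups for the BFloat16 branch in IsConvolution2dSupported.
armnn::TensorInfo bf16Input ({1, 16, 16, 8}, armnn::DataType::BFloat16);
armnn::TensorInfo fp32Output({1, 16, 16, 8}, armnn::DataType::Float32);  // accepted by the check above
armnn::TensorInfo fp16Output({1, 16, 16, 8}, armnn::DataType::Float16);  // rejected: "Output tensor type must be BFloat16 or Float32 for BFloat16 input."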
595 
596 bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
597  const TensorInfo& output,
598  const Convolution3dDescriptor& descriptor,
599  const TensorInfo& weights,
600  const Optional<TensorInfo>& biases,
601  Optional<std::string&> reasonIfUnsupported) const
602 {
603  bool supported = true;
604 
605  // Define supported types.
606  std::array<DataType,7> supportedTypes =
607  {
615  };
616 
617  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
618  "Reference Convolution3d: input is not a supported type.");
619 
620  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
621  "Reference Convolution3d: output is not a supported type.");
622 
623  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
624  "Reference Convolution3d: input and output types mismatched.");
625 
626  const DataType inputType = input.GetDataType();
627  if (IsQuantized8BitType(inputType))
628  {
629  std::array<DataType, 3> supportedWeightTypes =
630  {
634  };
635 
636  supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
637  "Reference Convolution3d: weights type not supported for quantized input.");
638  }
639  else
640  {
641  supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
642  "Reference Convolution3d: weights is not a supported type.");
643 
644  supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
645  "Reference Convolution3d: input and weights types mismatched.");
646  }
647 
648  if (biases.has_value())
649  {
650  std::array<DataType,4> biasesSupportedTypes =
651  {
656  };
657 
658  supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
659  "Reference Convolution3d: biases is not a supported type.");
660  }
661  IgnoreUnused(descriptor);
662 
663  return supported;
664 }
665 
666 bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
667  const TensorInfo& output,
668  Optional<std::string&> reasonIfUnsupported) const
669 {
670  bool supported = true;
671 
672  std::array<DataType, 8> supportedTypes =
673  {
682  };
683 
684  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
685  "Reference for Debug layer: input type not supported");
686 
687  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
688  "Reference for Debug layer: output type not supported");
689 
690  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
691  "Reference for Debug layer: input and output types are mismatched");
692 
693  return supported;
694 }
695 
696 bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
697  const TensorInfo& output,
698  const DepthToSpaceDescriptor& descriptor,
699  Optional<std::string&> reasonIfUnsupported) const
700 {
701  IgnoreUnused(descriptor);
702  bool supported = true;
703 
704  std::array<DataType,6> supportedTypes =
705  {
712  };
713 
714  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
715  "Reference DepthToSpace: input type not supported");
716 
717  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
718  "Reference DepthToSpace: output type not supported");
719 
720  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
721  "Reference DepthToSpace: input and output types are mismatched");
722 
723  return supported;
724 }
725 
726 bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
727  const TensorInfo& output,
728  const DepthwiseConvolution2dDescriptor& descriptor,
729  const TensorInfo& weights,
730  const Optional<TensorInfo>& biases,
731  Optional<std::string&> reasonIfUnsupported) const
732 {
733  IgnoreUnused(descriptor);
734  bool supported = true;
735 
736  // Define supported types.
737  std::array<DataType,7> supportedTypes =
738  {
746  };
747 
748  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
749  "Reference DepthwiseConvolution2d: input is not a supported type.");
750 
751  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
752  "Reference DepthwiseConvolution2d: output is not a supported type.");
753 
754  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
755  "Reference DepthwiseConvolution2d: input and output types mismatched.");
756 
757  const DataType inputType = input.GetDataType();
758  if (IsQuantized8BitType(inputType))
759  {
760  std::array<DataType, 3> supportedWeightTypes =
761  {
765  };
766 
767  supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
768  "Reference DepthwiseConvolution2d: weights type not supported for "
769  "quantized input.");
770  }
771  else
772  {
773  supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
774  "Reference DepthwiseConvolution2d: weights is not a supported type.");
775 
776  supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
777  "Reference DepthwiseConvolution2d: input and weights types mismatched.");
778  }
779 
780  if (biases.has_value())
781  {
782  std::array<DataType,4> biasesSupportedTypes =
783  {
788  };
789  supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
790  "Reference DepthwiseConvolution2d: biases is not a supported type.");
791  }
792 
793  return supported;
794 
795 }
796 
797 bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
798  const TensorInfo& output,
799  Optional<std::string&> reasonIfUnsupported) const
800 {
801  bool supported = true;
802 
803  std::array<DataType,4> supportedInputTypes = {
808  };
809 
810  supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
811  "Reference for Dequantize layer: input type not supported.");
812 
813  supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
814  "Reference for Dequantize layer: per-axis quantized input not supported.");
815 
816  std::array<DataType,3> supportedOutputTypes = {
820  };
821 
822  supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
823  "Reference for Dequantize layer: output type not supported.");
824 
825  supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
826  "Reference for Dequantize layer: input/output shapes have different num total "
827  "elements.");
828 
829  return supported;
830 }
831 
832 bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
833  const TensorInfo& scores,
834  const TensorInfo& anchors,
835  const TensorInfo& detectionBoxes,
836  const TensorInfo& detectionClasses,
837  const TensorInfo& detectionScores,
838  const TensorInfo& numDetections,
839  const DetectionPostProcessDescriptor& descriptor,
840  Optional<std::string&> reasonIfUnsupported) const
841 {
842  IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
843 
844  bool supported = true;
845 
846  std::array<DataType,6> supportedInputTypes =
847  {
854  };
855 
856  supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
857  "Reference DetectionPostProcess: input 0 is not a supported type.");
858 
859  supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
860  "Reference DetectionPostProcess: input 1 is not a supported type.");
861 
862  return supported;
863 }
864 
865 bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
866  const TensorInfo& output,
867  const DepthwiseConvolution2dDescriptor& descriptor,
868  const TensorInfo& weights,
869  const Optional<TensorInfo>& biases,
870  Optional<std::string&> reasonIfUnsupported) const
871 {
872  return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
873 }
874 
875 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
876  const TensorInfo& input1,
877  const TensorInfo& output,
878  Optional<std::string&> reasonIfUnsupported) const
879 {
880  bool supported = true;
881 
882  std::array<DataType,7> supportedTypes = {
890  };
891 
892  supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
893  "Reference division: input 0 is not a supported type.");
894 
895  supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
896  "Reference division: input 1 is not a supported type.");
897 
898  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
899  "Reference division: output is not a supported type.");
900 
901  supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
902  "Reference division: input 0 and input 1 types are mismatched");
903 
904  supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
905  "Reference division: input and output types are mismatched");
906 
907  supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
908  "Reference division: shapes are not suitable for implicit broadcast.");
909 
910  return supported;
911 }
912 
913 bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
914  const TensorInfo& output,
915  const ElementwiseUnaryDescriptor& descriptor,
916  Optional<std::string&> reasonIfUnsupported) const
917 {
918  IgnoreUnused(descriptor);
919 
920  std::array<DataType, 7> supportedTypes =
921  {
929  };
930 
931  std::array<DataType, 1> logicalSupportedTypes =
932  {
934  };
935 
936  bool supported = true;
937 
938  if (descriptor.m_Operation == UnaryOperation::LogicalNot)
939  {
940  supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
941  "Reference elementwise unary: input type not supported");
942 
943  supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
944  "Reference elementwise unary: output type not supported");
945  }
946  else
947  {
948  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
949  "Reference elementwise unary: input type not supported");
950 
951  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
952  "Reference elementwise unary: output type not supported");
953  }
954 
955  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
956  "Reference elementwise unary: input and output types not matching");
957 
958  supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
959  "Reference elementwise unary: input and output shapes "
960  "have different number of total elements");
961 
962  return supported;
963 }
964 
965 bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
966  const FakeQuantizationDescriptor& descriptor,
967  Optional<std::string&> reasonIfUnsupported) const
968 {
969  IgnoreUnused(descriptor);
970  bool supported = true;
971 
972  std::array<DataType,1> supportedTypes =
973  {
975  };
976 
977  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
978  "Reference fake quantization: input type not supported.");
979 
980  return supported;
981 }
982 
983 bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
984  const TensorInfo& output,
985  const FillDescriptor& descriptor,
986  Optional<std::string&> reasonIfUnsupported) const
987 {
988  IgnoreUnused(descriptor);
989  IgnoreUnused(output);
990 
991  bool supported = true;
992 
993  std::array<DataType,3> supportedTypes =
994  {
998  };
999 
1000  supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
1001  "Reference Fill: input type not supported.");
1002 
1003  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1004  "Reference Fill: output type not supported.");
1005  return supported;
1006 }
1007 
1008 bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1009  const TensorInfo& output,
1010  Optional<std::string&> reasonIfUnsupported) const
1011 {
1012  IgnoreUnused(output);
1013  bool supported = true;
1014 
1015  std::array<DataType,3> supportedTypes =
1016  {
1020  };
1021 
1022  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1023  "Reference Floor: input type not supported.");
1024 
1025  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1026  "Reference Floor: output type not supported.");
1027 
1028  return supported;
1029 }
1030 
1031 bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1032  const TensorInfo& output,
1033  const TensorInfo& weights,
1034  const TensorInfo& biases,
1035  const FullyConnectedDescriptor& descriptor,
1036  Optional<std::string&> reasonIfUnsupported) const
1037 {
1038  bool supported = true;
1039 
1040  // Define supported types.
1041  std::array<DataType,6> supportedTypes =
1042  {
1049  };
1050 
1051  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1052  "Reference Fully Connected: input type not supported.");
1053 
1054  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1055  "Reference Fully Connected: output type not supported.");
1056 
1057  supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1058  "Reference Fully Connected: weights type not supported.");
1059 
1060  // For FullyConnected, BFloat16 input with Float32 output is allowed as an optimization.
1061  if (input.GetDataType() == DataType::BFloat16)
1062  {
1063  if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1064  {
1065  reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1066  supported = false;
1067  }
1068  }
1069  else
1070  {
1071  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1072  "Reference Fully Connected: input and output types mismatched.");
1073  }
1074 
1075  supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1076  "Reference Fully Connected: weights is not a supported type.");
1077 
1078  supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1079  "Reference Fully Connected: input and weights types mismatched.");
1080 
1081  if (descriptor.m_BiasEnabled)
1082  {
1083  // Defined supported types for bias
1084  std::array<DataType, 5>
1085  supportedBiasTypes =
1086  {
1092  };
1093 
1094  supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1095  "Reference Fully Connected: bias type not supported.");
1096 
1097  supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1098  "Reference Fully Connected: bias and weight types mismatch.");
1099 
1100  supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1101  "Reference Fully Connected: bias type inferred from weights is incompatible.");
1102 
1103  supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1104  "Reference Fully Connected: bias must have 1 dimension.");
1105 
1106  }
1107 
1108  return supported;
1109 }
1110 
1111 bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1112  const armnn::TensorInfo& input1,
1113  const armnn::TensorInfo& output,
1114  const GatherDescriptor& descriptor,
1115  armnn::Optional<std::string&> reasonIfUnsupported) const
1116 {
1117  bool supported = true;
1118  std::array<DataType,7> supportedTypes =
1119  {
1127  };
1128 
1129  if (descriptor.m_Axis != 0)
1130  {
1131  reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1132  supported &= false;
1133  }
1134  supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1135  "Reference Gather: input type not supported");
1136 
1137  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1138  "Reference Gather: output type not supported");
1139 
1140  supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1141  "Reference Gather: indices (input1) type not supported");
1142 
1143  supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1144  "Reference Gather: input and output types not matching");
1145 
1146  return supported;
1147 }
1148 
1149 bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1150  Optional<std::string&> /*reasonIfUnsupported*/) const
1151 {
1152  return true;
1153 }
1154 
1155 bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1156  const TensorInfo& output,
1157  const InstanceNormalizationDescriptor& descriptor,
1158  Optional<std::string&> reasonIfUnsupported) const
1159 {
1160  IgnoreUnused(descriptor);
1161  // Define supported types
1162  std::array<DataType, 3> supportedTypes =
1163  {
1167  };
1168 
1169  bool supported = true;
1170 
1171  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1172  "Reference Instance Normalization: input type not supported.");
1173 
1174  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1175  "Reference Instance Normalization: output type not supported.");
1176 
1177  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1178  "Reference Instance Normalization: input and output types mismatched.");
1179 
1180  supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1181  "Reference Instance Normalization: input and output shapes have different "
1182  "num total elements.");
1183 
1184  return supported;
1185 }
1186 
1187 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1188  const TensorInfo& output,
1189  const L2NormalizationDescriptor& descriptor,
1190  Optional<std::string&> reasonIfUnsupported) const
1191 {
1192  IgnoreUnused(descriptor);
1193  // Define supported types
1194  std::array<DataType, 6> supportedTypes =
1195  {
1202  };
1203 
1204  bool supported = true;
1205 
1206  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1207  "Reference L2normalization: input type not supported.");
1208 
1209  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1210  "Reference L2normalization: output type not supported.");
1211 
1212  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1213  "Reference L2normalization: input and output types mismatched.");
1214 
1215  supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1216  "Reference L2normalization: input and output shapes have different "
1217  "num total elements.");
1218 
1219  return supported;
1220 }
1221 
1222 bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1223  const TensorInfo& input1,
1224  const TensorInfo& output,
1225  const LogicalBinaryDescriptor& descriptor,
1226  Optional<std::string&> reasonIfUnsupported) const
1227 {
1228  IgnoreUnused(descriptor);
1229 
1230  std::array<DataType, 1> supportedTypes =
1231  {
1233  };
1234 
1235  bool supported = true;
1236  supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1237  "Reference LogicalBinary: input 0 type not supported");
1238  supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1239  "Reference LogicalBinary: input 1 type not supported");
1240 
1241  supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1242  "Reference LogicalBinary: input and output types do not match");
1243 
1244  return supported;
1245 }
1246 
1247 bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1248  const TensorInfo& output,
1249  const LogSoftmaxDescriptor& descriptor,
1250  Optional<std::string&> reasonIfUnsupported) const
1251 {
1252  IgnoreUnused(descriptor);
1253 
1254  std::array<DataType, 3> supportedTypes =
1255  {
1259  };
1260 
1261  bool supported = true;
1262  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1263  "Reference LogSoftmax: input type not supported");
1264 
1265  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1266  "Reference LogSoftmax: output type not supported");
1267 
1268  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1269  "Reference LogSoftmax: input and output types do not match");
1270 
1271  return supported;
1272 }
1273 
1274 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1275  const TensorInfo& outputStateIn,
1276  const TensorInfo& cellStateIn,
1277  const TensorInfo& scratchBuffer,
1278  const TensorInfo& outputStateOut,
1279  const TensorInfo& cellStateOut,
1280  const TensorInfo& output,
1281  const LstmDescriptor& descriptor,
1282  const LstmInputParamsInfo& paramsInfo,
1283  Optional<std::string&> reasonIfUnsupported) const
1284 {
1285  IgnoreUnused(descriptor);
1286  IgnoreUnused(paramsInfo);
1287 
1288  bool supported = true;
1289 
1290  std::array<DataType,3> supportedTypes = {
1294  };
1295 
1296  // check inputs and outputs
1297  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1298  "Reference Lstm: input is not a supported type.");
1299  supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1300  "Reference Lstm: input and outputStateIn types are mismatched");
1301  supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1302  "Reference Lstm: input and cellStateIn types are mismatched");
1303  supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1304  "Reference Lstm: input and scratchBuffer types are mismatched");
1305  supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1306  "Reference Lstm: input and outputStateOut types are mismatched");
1307  supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1308  "Reference Lstm: input and cellStateOut types are mismatched");
1309 
1310  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1311  "Reference Lstm: input and output types are mismatched");
1312  // check layer parameters
1313  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
1314  "Reference Lstm: input and InputToForgetWeights types are mismatched");
1315  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
1316  "Reference Lstm: input and InputToCellWeights types are mismatched");
1317  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
1318  "Reference Lstm: input and InputToOutputWeights types are mismatched");
1319  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
1320  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
1321  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
1322  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
1323  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
1324  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
1325  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
1326  "Reference Lstm: input and ForgetGateBias types are mismatched");
1327  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
1328  "Reference Lstm: input and CellBias types are mismatched");
1329  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
1330  "Reference Lstm: input and OutputGateBias types are mismatched");
1331  if (!descriptor.m_CifgEnabled)
1332  {
1333  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
1334  "Reference Lstm: input and InputToInputWeights types are mismatched");
1335  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
1336  reasonIfUnsupported,
1337  "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
1338  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
1339  "Reference Lstm: input and InputGateBias types are mismatched");
1340  if (descriptor.m_PeepholeEnabled)
1341  {
1342  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
1343  reasonIfUnsupported,
1344  "Reference Lstm: input and CellToInputWeights types are mismatched");
1345  }
1346  }
1347  if (descriptor.m_PeepholeEnabled)
1348  {
1349  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
1350  "Reference Lstm: input and CellToForgetWeights types are mismatched");
1351  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
1352  "Reference Lstm: input and CellToOutputWeights types are mismatched");
1353  }
1354  if (descriptor.m_ProjectionEnabled)
1355  {
1356  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
1357  "Reference Lstm: input and ProjectionWeights types are mismatched");
1358  if (paramsInfo.m_ProjectionBias != nullptr)
1359  {
1360  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
1361  "Reference Lstm: input and ProjectionBias types are mismatched");
1362  }
1363  }
1364  if (descriptor.m_LayerNormEnabled)
1365  {
1366  if (!descriptor.m_CifgEnabled)
1367  {
1368  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
1369  reasonIfUnsupported,
1370  "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1371  }
1372  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
1373  reasonIfUnsupported,
1374  "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1375  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
1376  reasonIfUnsupported,
1377  "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1378  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
1379  reasonIfUnsupported,
1380  "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1381  }
1382 
1383  return supported;
1384 }
1385 
1386 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1387  const TensorInfo& input1,
1388  const TensorInfo& output,
1389  Optional<std::string&> reasonIfUnsupported) const
1390 {
1391  bool supported = true;
1392 
1393  std::array<DataType,7> supportedTypes = {
1401  };
1402 
1403  supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1404  "Reference maximum: input 0 is not a supported type.");
1405 
1406  supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1407  "Reference maximum: input 1 is not a supported type.");
1408 
1409  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1410  "Reference maximum: output is not a supported type.");
1411 
1412  supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1413  "Reference maximum: input 0 and input 1 types are mismatched");
1414 
1415  supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1416  "Reference maximum: input and output types are mismatched");
1417 
1418  supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1419  "Reference maximum: shapes are not suitable for implicit broadcast.");
1420 
1421  return supported;
1422 }
1423 
1424 bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1425  const TensorInfo& output,
1426  const MeanDescriptor& descriptor,
1427  Optional<std::string&> reasonIfUnsupported) const
1428 {
1429  bool supported = true;
1430  std::string meanLayerStr = "Mean";
1431  std::string outputTensorStr = "output";
1432 
1433  std::array<DataType,6> supportedTypes =
1434  {
1441  };
1442 
1443  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1444  "Reference Mean: input type not supported.");
1445 
1446  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1447  "Reference Mean: input and output types are mismatched");
1448 
1449  if (descriptor.m_KeepDims)
1450  {
1451  supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1452  reasonIfUnsupported,
1453  CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1454  output.GetNumDimensions(),
1455  meanLayerStr, outputTensorStr).data());
1456  }
1457  else if (descriptor.m_Axis.empty())
1458  {
1459  supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1460  reasonIfUnsupported,
1461  CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1462  meanLayerStr, outputTensorStr).data());
1463  }
1464  else
1465  {
1466  auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1467 
1468  if (outputDim > 0)
1469  {
1470  supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1471  reasonIfUnsupported,
1472  CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1473  meanLayerStr, outputTensorStr).data());
1474  }
1475  else
1476  {
1477  supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1478  reasonIfUnsupported,
1479  CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1480  meanLayerStr, outputTensorStr).data());
1481  }
1482  }
1483 
1484  return supported;
1485 }
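A worked example of the rank rules above (illustrative only, not part of the listed file):

// For IsMeanSupported with input shape { 1, 3, 224, 224 } and descriptor.m_Axis = { 2, 3 }:
//   m_KeepDims == true               -> output must keep 4 dimensions, e.g. { 1, 3, 1, 1 }
//   m_KeepDims == false              -> outputDim = 4 - 2 = 2, so the output must be 2D, e.g. { 1, 3 }
//   m_Axis empty, m_KeepDims == false -> output must be 1D (fully reduced)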
1486 
1487 bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1488  const TensorInfo &output,
1489  Optional<std::string &> reasonIfUnsupported) const
1490 {
1491  bool supported = true;
1492 
1493  std::array<DataType,7> supportedTypes =
1494  {
1502  };
1503 
1504  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1505  "Reference MemCopy: input type not supported");
1506 
1507  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1508  "Reference MemCopy: output type not supported");
1509 
1510  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1511  "Reference MemCopy: input and output types are mismatched");
1512 
1513  return supported;
1514 }
1515 
1516 bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1517  const TensorInfo& input1,
1518  const TensorInfo& output,
1519  Optional<std::string&> reasonIfUnsupported) const
1520 {
1521  bool supported = true;
1522 
1523  std::array<DataType,7> supportedTypes = {
1531  };
1532 
1533  supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1534  "Reference minimum: input 0 is not a supported type.");
1535 
1536  supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1537  "Reference minimum: input 1 is not a supported type.");
1538 
1539  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1540  "Reference minimum: output is not a supported type.");
1541 
1542  supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1543  "Reference minimum: input 0 and input 1 types are mismatched");
1544 
1545  supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1546  "Reference minimum: input and output types are mismatched");
1547 
1548  supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1549  "Reference minimum: shapes are not suitable for implicit broadcast.");
1550 
1551  return supported;
1552 }
1553 
1554 bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1555  const TensorInfo& input1,
1556  const TensorInfo& output,
1557  Optional<std::string&> reasonIfUnsupported) const
1558 {
1559  bool supported = true;
1560 
1561  std::array<DataType,7> supportedTypes = {
1569  };
1570 
1571  supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1572  "Reference multiplication: input 0 is not a supported type.");
1573 
1574  supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1575  "Reference multiplication: input 1 is not a supported type.");
1576 
1577  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1578  "Reference multiplication: output is not a supported type.");
1579 
1580  supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1581  "Reference multiplication: input 0 and input 1 types are mismatched");
1582 
1583  supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1584  "Reference multiplication: input and output types are mismatched");
1585 
1586  supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1587  "Reference multiplication: shapes are not suitable for implicit broadcast.");
1588 
1589  return supported;
1590 }
1591 
1592 bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1593  const TensorInfo& output,
1594  const NormalizationDescriptor& descriptor,
1595  Optional<std::string&> reasonIfUnsupported) const
1596 {
1597  IgnoreUnused(descriptor);
1598 
1599  // Define supported types
1600  std::array<DataType, 6> supportedTypes =
1601  {
1608  };
1609 
1610  bool supported = true;
1611 
1612  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1613  "Reference normalization: input type not supported.");
1614 
1615  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1616  "Reference normalization: output type not supported.");
1617 
1618  supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1619  "Reference normalization: input and output shapes have different "
1620  "num total elements.");
1621 
1622  return supported;
1623 }
1624 
1625 bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1626  Optional<std::string&> /*reasonIfUnsupported*/) const
1627 {
1628  return true;
1629 }
1630 
1631 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1632  const TensorInfo& output,
1633  const PadDescriptor& descriptor,
1634  Optional<std::string&> reasonIfUnsupported) const
1635 {
1636  IgnoreUnused(descriptor);
1637  bool supported = true;
1638 
1639  // Define supported output and inputs types.
1640  std::array<DataType,6> supportedTypes =
1641  {
1648  };
1649 
1650  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1651  "Reference pad: input is not a supported type.");
1652 
1653  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1654  "Reference pad: output is not a supported type.");
1655 
1656  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1657  "Reference pad: input and output types are mismatched.");
1658 
1659  return supported;
1660 }
1661 
1662 bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1663  const TensorInfo& output,
1664  const PermuteDescriptor& descriptor,
1665  Optional<std::string&> reasonIfUnsupported) const
1666 {
1667  IgnoreUnused(descriptor);
1668  bool supported = true;
1669 
1670  // Define supported output and inputs types.
1671  std::array<DataType, 6> supportedTypes =
1672  {
1679  };
1680 
1681  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1682  "Reference permute: input is not a supported type.");
1683 
1684  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1685  "Reference permute: output is not a supported type.");
1686 
1687  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1688  "Reference permute: input and output types are mismatched.");
1689 
1690  return supported;
1691 }
1692 
1693 bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1694  const TensorInfo& output,
1695  const Pooling2dDescriptor& descriptor,
1696  Optional<std::string&> reasonIfUnsupported) const
1697 {
1698  IgnoreUnused(descriptor);
1699  bool supported = true;
1700 
1701  // Define supported output and inputs types.
1702  std::array<DataType,6> supportedTypes =
1703  {
1710  };
1711 
1712  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1713  "Reference pooling2d: input is not a supported type.");
1714 
1715  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1716  "Reference pooling2d: output is not a supported type.");
1717 
1718  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1719  "Reference pooling2d: input and output types are mismatched.");
1720 
1721  return supported;
1722 }
1723 
1724 bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
1725  const TensorInfo& previousOutputIn,
1726  const TensorInfo& previousCellStateIn,
1727  const TensorInfo& outputStateOut,
1728  const TensorInfo& cellStateOut,
1729  const TensorInfo& output,
1730  const QLstmDescriptor& descriptor,
1731  const LstmInputParamsInfo& paramsInfo,
1732  Optional<std::string&> reasonIfUnsupported) const
1733 {
1734  IgnoreUnused(input);
1735  IgnoreUnused(previousOutputIn);
1736  IgnoreUnused(previousCellStateIn);
1737  IgnoreUnused(outputStateOut);
1738  IgnoreUnused(cellStateOut);
1739  IgnoreUnused(output);
1740  IgnoreUnused(descriptor);
1741  IgnoreUnused(paramsInfo);
1742 
1743  IgnoreUnused(reasonIfUnsupported);
1744 
1745  return true;
1746 }
1747 
1748 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1749  const TensorInfo& output,
1750  Optional<std::string&> reasonIfUnsupported) const
1751 {
1752  bool supported = true;
1753 
1754  // Define supported input types.
1755  std::array<DataType,7> supportedInputTypes = {
1763  };
1764 
1765  supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1766  "Reference quantize: input type not supported.");
1767 
1768  // Define supported output types.
1769  std::array<DataType,4> supportedOutputTypes = {
1773  DataType::QSymmS16
1774  };
1775  supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1776  "Reference quantize: output type not supported.");
1777 
1778  supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1779  "Reference quantize: input and output shapes have different num total elements.");
1780 
1781  return supported;
1782 }
1783 
1784 bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
1785  const TensorInfo& output,
1786  Optional<std::string&> reasonIfUnsupported) const
1787 {
1788  IgnoreUnused(input);
1789  // Define supported output types.
1790  std::array<DataType,1> supportedOutputTypes =
1791  {
1793  };
1794 
1795  return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1796  "Reference rank: output type not supported.");
1797 }
1798 
1799 bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
1800  const TensorInfo& output,
1801  const ReduceDescriptor& descriptor,
1802  Optional<std::string&> reasonIfUnsupported) const
1803 {
1804  IgnoreUnused(descriptor);
1805  bool supported = true;
1806  std::array<DataType,7> supportedTypes =
1807  {
1815  };
1816 
1817  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1818  "Reference Reduce: input type not supported");
1819 
1820  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1821  "Reference Reduce: output type not supported");
1822 
1823  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1824  "Reference Reduce: input and output types not matching");
1825 
1826  return supported;
1827 }
1828 
1829 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
1830  const TensorInfo& output,
1831  const ReshapeDescriptor& descriptor,
1832  Optional<std::string&> reasonIfUnsupported) const
1833 {
1834  IgnoreUnused(output);
1835  IgnoreUnused(descriptor);
1836  // Define supported output types.
1837  std::array<DataType,8> supportedOutputTypes =
1838  {
1847  };
1848 
1849  return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1850  "Reference reshape: input type not supported.");
1851 }
1852 
1853 bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1854  const TensorInfo& output,
1855  const ResizeDescriptor& descriptor,
1856  Optional<std::string&> reasonIfUnsupported) const
1857 {
1858  IgnoreUnused(descriptor);
1859  bool supported = true;
1860  std::array<DataType,6> supportedTypes =
1861  {
1868  };
1869 
1870  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1871  "Reference Resize: input type not supported");
1872 
1873  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1874  "Reference Resize: output type not supported");
1875 
1876  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1877  "Reference Resize: input and output types not matching");
1878 
1879  return supported;
1880 }
1881 
1882 bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
1883  const TensorInfo& output,
1884  Optional<std::string&> reasonIfUnsupported) const
1885 {
1886  IgnoreUnused(input);
1887  bool supported = true;
1888 
1889  std::array<DataType, 1> supportedTypes =
1890  {
1892  };
1893 
1894  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1895  "Reference Shape: output type not supported");
1896 
1897  return supported;
1898 }
1899 
1900 bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1901  const TensorInfo& output,
1902  const SliceDescriptor& descriptor,
1903  Optional<std::string&> reasonIfUnsupported) const
1904 {
1905  IgnoreUnused(descriptor);
1906  bool supported = true;
1907 
1908  std::array<DataType, 5> supportedTypes =
1909  {
1915  };
1916 
1917  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1918  "Reference Slice: input type not supported");
1919 
1920  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1921  "Reference Slice: output type not supported");
1922 
1923  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1924  "Reference Slice: input and output types are mismatched");
1925 
1926  return supported;
1927 }
1928 
1929 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1930  const TensorInfo& output,
1931  const SoftmaxDescriptor& descriptor,
1932  Optional<std::string&> reasonIfUnsupported) const
1933 {
1934  IgnoreUnused(descriptor);
1935  bool supported = true;
1936  std::array<DataType,7> supportedTypes =
1937  {
1945  };
1946 
1947  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1948  "Reference Softmax: input type not supported");
1949 
1950  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1951  "Reference Softmax: output type not supported");
1952 
1953  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1954  "Reference Softmax: input and output types are mismatched");
1955 
1956  return supported;
1957 }
1958 
1959 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1960  const TensorInfo& output,
1961  const SpaceToBatchNdDescriptor& descriptor,
1962  Optional<std::string&> reasonIfUnsupported) const
1963 {
1964  IgnoreUnused(descriptor);
1965  bool supported = true;
1966  std::array<DataType,6> supportedTypes =
1967  {
1974  };
1975 
1976  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1977  "Reference SpaceToBatchNd: input type not supported");
1978 
1979  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1980  "Reference SpaceToBatchNd: output type not supported");
1981 
1982  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1983  "Reference SpaceToBatchNd: input and output types are mismatched");
1984 
1985  return supported;
1986 }
1987 
1988 bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
1989  const TensorInfo& output,
1990  const SpaceToDepthDescriptor& descriptor,
1991  Optional<std::string&> reasonIfUnsupported) const
1992 {
1993 
1994  IgnoreUnused(descriptor);
1995  bool supported = true;
1996 
1997  std::array<DataType,6> supportedTypes =
1998  {
2005  };
2006 
2007  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2008  "Reference SpaceToDepth: input type not supported");
2009 
2010  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2011  "Reference SpaceToDepth: output type not supported");
2012 
2013  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2014  "Reference SpaceToDepth: input and output types are mismatched");
2015 
2016  return supported;
2017 }
2018 
2019 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2020  const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2021  const ViewsDescriptor& descriptor,
2022  Optional<std::string&> reasonIfUnsupported) const
2023 {
2024  IgnoreUnused(descriptor);
2025  bool supported = true;
2026  std::array<DataType,6> supportedTypes =
2027  {
2034  };
2035 
2036  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2037  "Reference splitter: input type not supported");
2038  for (const TensorInfo& output : outputs)
2039  {
2040  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2041  "Reference splitter: output type not supported");
2042 
2043  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2044  "Reference splitter: input and output types mismatched.");
2045  }
2046 
2047  return supported;
2048 }
2049 
2050 bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2051  const TensorInfo& output,
2052  const StackDescriptor& descriptor,
2053  Optional<std::string&> reasonIfUnsupported) const
2054 {
2055  IgnoreUnused(descriptor);
2056 
2057  bool supported = true;
2058  std::array<DataType,6> supportedTypes =
2059  {
2066  };
2067 
2068  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2069  "Reference stack: output type not supported");
2070  for (const TensorInfo* input : inputs)
2071  {
2072  ARMNN_ASSERT(input != nullptr);
2073  supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2074  "Reference stack: input type not supported");
2075 
2076  supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2077  "Reference stack: input and output types mismatched.");
2078  }
2079 
2080  return supported;
2081 }
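Unlike most of the checks in this file, IsStackSupported receives its inputs as a vector of TensorInfo pointers, and every input is validated against the output inside the loop above. The fragment below is a minimal, illustrative sketch of driving this check directly; the shapes, axis value, and reason handling are invented for the example and are not taken from this file.

// Illustrative sketch: query stack support on the reference backend.
// Assumes the backend-internal header is on the include path, as it is for this file.
#include "RefLayerSupport.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <string>
#include <vector>

bool IsStackOfTwoSupported()
{
    using namespace armnn;

    TensorInfo input0(TensorShape({3, 4}), DataType::Float32);
    TensorInfo input1(TensorShape({3, 4}), DataType::Float32);
    TensorInfo output(TensorShape({2, 3, 4}), DataType::Float32); // two inputs stacked on axis 0

    StackDescriptor descriptor(/*axis=*/0, /*numInputs=*/2, TensorShape({3, 4}));

    std::vector<const TensorInfo*> inputs = { &input0, &input1 };

    std::string reason;
    RefLayerSupport layerSupport;
    const bool supported =
        layerSupport.IsStackSupported(inputs, output, descriptor, Optional<std::string&>(reason));

    // On failure, every violated rule has appended its message to 'reason'.
    return supported;
}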
2082 
2083 bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2084  const TensorInfo& output,
2085  const StridedSliceDescriptor& descriptor,
2086  Optional<std::string&> reasonIfUnsupported) const
2087 {
2088  IgnoreUnused(descriptor);
2089  bool supported = true;
2090 
2091  std::array<DataType,5> supportedTypes =
2092  {
2098  };
2099 
2100  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2101  "Reference StridedSlice: input type not supported");
2102 
2103  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2104  "Reference StridedSlice: output type not supported");
2105 
2106  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2107  "Reference StridedSlice: input and output types are mismatched");
2108 
2109  return supported;
2110 }
2111 
2112 bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2113  const TensorInfo& input1,
2114  const TensorInfo& output,
2115  Optional<std::string&> reasonIfUnsupported) const
2116 {
2117  bool supported = true;
2118 
2119  std::array<DataType,7> supportedTypes = {
2127  };
2128 
2129  supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2130  "Reference subtraction: input 0 is not a supported type.");
2131 
2132  supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2133  "Reference subtraction: input 1 is not a supported type.");
2134 
2135  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2136  "Reference subtraction: output is not a supported type.");
2137 
2138  supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2139  "Reference subtraction: input 0 and input 1 types are mismatched");
2140 
2141  supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2142  "Reference subtraction: input and output types are mismatched");
2143 
2144  supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2145  "Reference subtraction: shapes are not suitable for implicit broadcast.");
2146 
2147  return supported;
2148 }
2149 
2150 bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2151  const TensorInfo& alpha,
2152  const TensorInfo& output,
2153  Optional<std::string&> reasonIfUnsupported) const
2154 {
2155  bool supported = true;
2156 
2157  std::array<DataType, 6> supportedTypes
2158  {
2165  };
2166 
2167  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2168  "PReLU: input is not a supported type.");
2169 
2170  supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2171  "PReLU: alpha is not a supported type.");
2172 
2173  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2174  "PReLU: output is not a supported type.");
2175 
2176  supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2177  "PReLU: input, alpha and output types are mismatched");
2178 
2179  supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2180  "PReLU: shapes are not suitable for implicit broadcast");
2181 
2182  return supported;
2183 }
2184 
2185 bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2186  const TensorInfo& output,
2187  const TransposeConvolution2dDescriptor& descriptor,
2188  const TensorInfo& weights,
2189  const Optional<TensorInfo>& biases,
2190  Optional<std::string&> reasonIfUnsupported) const
2191 {
2192  IgnoreUnused(descriptor);
2193  bool supported = true;
2194 
2195  std::array<DataType,7> supportedTypes =
2196  {
2204  };
2205 
2206  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2207  "Reference TransposeConvolution2d: input is not a supported type.");
2208 
2209  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2210  "Reference TransposeConvolution2d: output is not a supported type.");
2211 
2212  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2213  "Reference TransposeConvolution2d: input and output types mismatched.");
2214 
2215 
2216  const DataType inputType = input.GetDataType();
2217  if (IsQuantized8BitType(inputType))
2218  {
2219  std::array<DataType, 3> supportedWeightTypes =
2220  {
2224  };
2225 
2226  supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2227  "Reference TransposeConvolution2d: weights type not supported for "
2228  "quantized input.");
2229  }
2230  else
2231  {
2232  supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2233  "Reference TransposeConvolution2d: weights is not a supported type.");
2234 
2235  supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2236  "Reference TransposeConvolution2d: input and weights types mismatched.");
2237  }
2238 
2239  if (biases.has_value())
2240  {
2241  std::array<DataType,4> biasesSupportedTypes =
2242  {
2247  };
2248  supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2249  "Reference TransposeConvolution2d: biases is not a supported type.");
2250  }
2251 
2252  return supported;
2253 }
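The weights handling above is worth calling out: for an 8-bit quantized input the weights are validated against a dedicated list of quantized weight types and are not required to match the input's data type, whereas for non-quantized inputs the weights must be of the same type as the input; biases, when present, are checked against their own type list. Below is a hedged sketch of a quantized query; all shapes, strides, and quantization parameters are invented for illustration.

// Illustrative sketch: a quantized TransposeConvolution2d support query.
#include "RefLayerSupport.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <string>

bool IsQuantizedTransposeConvSupported()
{
    using namespace armnn;

    TensorInfo input ({1, 16, 16, 3}, DataType::QAsymmS8, /*scale*/ 0.05f,  /*offset*/ 0);
    TensorInfo output({1, 32, 32, 8}, DataType::QAsymmS8, /*scale*/ 0.10f,  /*offset*/ 0);
    TensorInfo weights({8, 3, 3, 3},  DataType::QSymmS8,  /*scale*/ 0.02f,  /*offset*/ 0);
    TensorInfo biases ({8},           DataType::Signed32, /*scale*/ 0.001f, /*offset*/ 0);

    TransposeConvolution2dDescriptor descriptor;
    descriptor.m_StrideX     = 2;
    descriptor.m_StrideY     = 2;
    descriptor.m_BiasEnabled = true;
    descriptor.m_DataLayout  = DataLayout::NHWC;

    std::string reason;
    RefLayerSupport layerSupport;
    return layerSupport.IsTransposeConvolution2dSupported(input, output, descriptor, weights,
                                                          Optional<TensorInfo>(biases),
                                                          Optional<std::string&>(reason));
}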
2254 
2255 bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2256  const TensorInfo& output,
2257  const TransposeDescriptor& descriptor,
2258  Optional<std::string&> reasonIfUnsupported) const
2259 {
2260  IgnoreUnused(descriptor);
2261  bool supported = true;
2262 
2263  // Define supported output and inputs types.
2264  std::array<DataType, 6> supportedTypes =
2265  {
2272  };
2273 
2274  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2275  "Reference transpose: input is not a supported type.");
2276 
2277  supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2278  "Reference transpose: output is not a supported type.");
2279 
2280  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2281  "Reference transpose: input and output types are mismatched.");
2282 
2283  return supported;
2284 }
2285 
2286 bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2287  const TensorInfo& input,
2288  const TensorInfo& outputStateIn,
2289  const TensorInfo& cellStateIn,
2290  const TensorInfo& output,
2291  const Optional<TensorInfo>& hiddenStateOutput,
2292  const Optional<TensorInfo>& cellStateOutput,
2293  const UnidirectionalSequenceLstmDescriptor& descriptor,
2294  const LstmInputParamsInfo& paramsInfo,
2295  Optional<std::string&> reasonIfUnsupported) const
2296 {
2297  IgnoreUnused(descriptor);
2298  IgnoreUnused(paramsInfo);
2299  IgnoreUnused(outputStateIn);
2300  IgnoreUnused(cellStateIn);
2301  bool supported = true;
2302 
2303  if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
2304  {
2305  reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
2306  "and cell state output are not supported at the moment.";
2307  supported = false;
2308  }
2309  std::array<DataType, 1> supportedTypes =
2310  {
2312  };
2313 
2314  std::array<DataType, 2> supportedWeightTypes =
2315  {
2318  };
2319 
2320  // check inputs and outputs
2321  supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2322  "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2323  supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
2324  "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
2325  supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
2326  "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
2327 
2328  supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2329  "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
2330  // check layer parameters
2331  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2332  reasonIfUnsupported,
2333  "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2334  "is not a supported type.");
2335  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2336  reasonIfUnsupported,
2337  "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2338  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2339  reasonIfUnsupported,
2340  "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2341  "is not a supported type.");
2342  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2343  reasonIfUnsupported,
2344  "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2345  "is not a supported type.");
2346  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2347  reasonIfUnsupported,
2348  "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2349  "is not a supported type.");
2350  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2351  reasonIfUnsupported,
2352  "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2353  "is not a supported type.");
2354  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
2355  "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
2356  "are mismatched");
2357  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
2358  "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
2359  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
2360  "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
2361  "are mismatched");
2362  if (!descriptor.m_CifgEnabled)
2363  {
2364  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2365  reasonIfUnsupported,
2366  "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2367  "is not a supported type.");
2368  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2369  reasonIfUnsupported,
2370  "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2371  "is not a supported type.");
2372  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
2373  "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
2374  "are mismatched");
2375  if (descriptor.m_PeepholeEnabled)
2376  {
2377  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2378  reasonIfUnsupported,
2379  "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2380  "is not a supported type.");
2381  }
2382  }
2383  if (descriptor.m_PeepholeEnabled)
2384  {
2385  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2386  reasonIfUnsupported,
2387  "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2388  "is not a supported type.");
2389  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2390  reasonIfUnsupported,
2391  "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2392  "is not a supported type.");
2393  }
2394  if (descriptor.m_ProjectionEnabled)
2395  {
2396  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2397  reasonIfUnsupported,
2398  "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2399  "is not a supported type.");
2400  if (paramsInfo.m_ProjectionBias != nullptr)
2401  {
2402  supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2403  "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2404  "are mismatched");
2405  }
2406  }
2407  if (descriptor.m_LayerNormEnabled)
2408  {
2409  if (!descriptor.m_CifgEnabled)
2410  {
2411  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2412  reasonIfUnsupported,
2413  "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2414  "is not a supported type.");
2415  }
2416  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2417  reasonIfUnsupported,
2418  "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2419  "is not a supported type.");
2420  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2421  reasonIfUnsupported,
2422  "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2423  "is not a supported type.");
2424  supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2425  reasonIfUnsupported,
2426  "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2427  "is not a supported type.");
2428  }
2429 
2430  return supported;
2431 }
2432 
2433 } // namespace armnn
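Every override above follows the same shape: build an array of supported DataTypes, AND together the results of successive CheckSupportRule calls, and let each failed rule append its message to reasonIfUnsupported. As a final illustration, here is a minimal sketch of querying softmax support against the class defined in this file. The shapes and beta value are invented; in an application the same information is normally reached through the backend registry (for example armnn::GetILayerSupportByBackendId("CpuRef")) rather than by constructing RefLayerSupport directly.

// Illustrative sketch: ask the reference backend whether a float32 softmax is supported.
#include "RefLayerSupport.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <iostream>
#include <string>

int main()
{
    using namespace armnn;

    TensorInfo input ({1, 10}, DataType::Float32);
    TensorInfo output({1, 10}, DataType::Float32);

    SoftmaxDescriptor descriptor;
    descriptor.m_Beta = 1.0f;

    std::string reason;
    RefLayerSupport layerSupport;
    const bool supported =
        layerSupport.IsSoftmaxSupported(input, output, descriptor, Optional<std::string&>(reason));

    if (!supported)
    {
        std::cout << "Softmax is not supported on CpuRef: " << reason << std::endl;
    }
    return supported ? 0 : 1;
}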