aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2018-10-15 11:11:51 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-22 16:57:53 +0100
commitb5acbb77918df98debac200ebe082ce9aaab6a8c (patch)
treec42599649a5ea50a198840c1dfa6462a98c00fce
parent3cc9a626773ae9e79d3d0bd9c120704676d44daa (diff)
downloadarmnn-b5acbb77918df98debac200ebe082ce9aaab6a8c.tar.gz
IVGCVSW-2004: Get rid of IsLayerSupportedRef functions in favor of ILayerSupport interface
Change-Id: Ia147a0b408b2ca951c214963432d6e0f9b27b973
-rw-r--r--src/backends/reference/RefLayerSupport.cpp671
-rw-r--r--src/backends/reference/RefLayerSupport.hpp158
2 files changed, 195 insertions, 634 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 3a250a6981..909df75445 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -5,26 +5,48 @@
#include "RefLayerSupport.hpp"
-#include <LayerSupportCommon.hpp>
-
-#include <armnn/Descriptors.hpp>
+#include <armnn/InternalTypes.hpp>
+#include <armnn/LayerSupportCommon.hpp>
#include <armnn/Types.hpp>
-#include <armnn/Tensor.hpp>
#include <boost/core/ignore_unused.hpp>
-#include "InternalTypes.hpp"
using namespace boost;
namespace armnn
{
+namespace
+{
+
+template<typename Float32Func, typename Uint8Func, typename ... Params>
+bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
+ DataType dataType,
+ Float32Func floatFuncPtr,
+ Uint8Func uint8FuncPtr,
+ Params&&... params)
+{
+ return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ dataType,
+ &FalseFunc<Params...>,
+ floatFuncPtr,
+ uint8FuncPtr,
+ std::forward<Params>(params)...);
+}
+
+} // anonymous namespace
+
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
const TensorInfo& output,
const ActivationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsActivationSupportedRef(input, output, descriptor, reasonIfUnsupported);
+ ignore_unused(output);
+ ignore_unused(descriptor);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
@@ -32,7 +54,12 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsAdditionSupportedRef(input0, input1, output, reasonIfUnsupported);
+ ignore_unused(input1);
+ ignore_unused(output);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input0.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
@@ -44,34 +71,57 @@ bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
const BatchNormalizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsBatchNormalizationSupportedRef(input,
- output,
- mean,
- var,
- beta,
- gamma,
- descriptor,
- reasonIfUnsupported);
+ ignore_unused(output);
+ ignore_unused(mean);
+ ignore_unused(var);
+ ignore_unused(beta);
+ ignore_unused(gamma);
+ ignore_unused(descriptor);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsConstantSupportedRef(output, reasonIfUnsupported);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ output.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsConvertFp16ToFp32SupportedRef(input, output, reasonIfUnsupported);
+ return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &FalseInputFuncF32<>,
+ &FalseFuncU8<>) &&
+ IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ output.GetDataType(),
+ &FalseOutputFuncF16<>,
+ &TrueFunc<>,
+ &FalseFuncU8<>));
}
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsConvertFp32ToFp16SupportedRef(input, output, reasonIfUnsupported);
+ return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ input.GetDataType(),
+ &FalseInputFuncF16<>,
+ &TrueFunc<>,
+ &FalseFuncU8<>) &&
+ IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ output.GetDataType(),
+ &TrueFunc<>,
+ &FalseOutputFuncF32<>,
+ &FalseFuncU8<>));
}
bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
@@ -81,12 +131,14 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
const Optional<TensorInfo>& biases,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsConvolution2dSupportedRef(input,
- output,
- descriptor,
- weights,
- biases,
- reasonIfUnsupported);
+ ignore_unused(output);
+ ignore_unused(descriptor);
+ ignore_unused(weights);
+ ignore_unused(biases);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
@@ -96,12 +148,14 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
const Optional<TensorInfo>& biases,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsDepthwiseConvolutionSupportedRef(input,
- output,
- descriptor,
- weights,
- biases,
- reasonIfUnsupported);
+ ignore_unused(output);
+ ignore_unused(descriptor);
+ ignore_unused(weights);
+ ignore_unused(biases);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
@@ -109,21 +163,34 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsDivisionSupportedRef(input0, input1, output, reasonIfUnsupported);
+ ignore_unused(input1);
+ ignore_unused(output);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input0.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsFakeQuantizationSupportedRef(input, descriptor, reasonIfUnsupported);
+ ignore_unused(descriptor);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &FalseFuncU8<>);
}
bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsFloorSupportedRef(input, output, reasonIfUnsupported);
+ ignore_unused(output);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &FalseFuncU8<>);
}
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
@@ -133,18 +200,23 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsFullyConnectedSupportedRef(input,
- output,
- weights,
- biases,
- descriptor,
- reasonIfUnsupported);
+ ignore_unused(output);
+ ignore_unused(weights);
+ ignore_unused(biases);
+ ignore_unused(descriptor);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsInputSupportedRef(input, reasonIfUnsupported);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
@@ -152,7 +224,12 @@ bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
const L2NormalizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsL2NormalizationSupportedRef(input, output, descriptor, reasonIfUnsupported);
+ ignore_unused(output);
+ ignore_unused(descriptor);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &FalseFuncU8<>);
}
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
@@ -182,32 +259,33 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
const TensorInfo* cellToOutputWeights,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsLstmSupportedRef(input,
- outputStateIn,
- cellStateIn,
- scratchBuffer,
- outputStateOut,
- cellStateOut,
- output,
- descriptor,
- inputToForgetWeights,
- inputToCellWeights,
- inputToOutputWeights,
- recurrentToForgetWeights,
- recurrentToCellWeights,
- recurrentToOutputWeights,
- forgetGateBias,
- cellBias,
- outputGateBias,
- inputToInputWeights,
- recurrentToInputWeights,
- cellToInputWeights,
- inputGateBias,
- projectionWeights,
- projectionBias,
- cellToForgetWeights,
- cellToOutputWeights,
- reasonIfUnsupported);
+ ignore_unused(input);
+ ignore_unused(outputStateIn);
+ ignore_unused(cellStateIn);
+ ignore_unused(scratchBuffer);
+ ignore_unused(outputStateOut);
+ ignore_unused(cellStateOut);
+ ignore_unused(output);
+ ignore_unused(descriptor);
+ ignore_unused(inputToForgetWeights);
+ ignore_unused(inputToCellWeights);
+ ignore_unused(inputToOutputWeights);
+ ignore_unused(recurrentToForgetWeights);
+ ignore_unused(recurrentToCellWeights);
+ ignore_unused(recurrentToOutputWeights);
+ ignore_unused(forgetGateBias);
+ ignore_unused(cellBias);
+ ignore_unused(outputGateBias);
+ ignore_unused(inputToInputWeights);
+ ignore_unused(recurrentToInputWeights);
+ ignore_unused(cellToInputWeights);
+ ignore_unused(inputGateBias);
+ ignore_unused(projectionWeights);
+ ignore_unused(projectionBias);
+ ignore_unused(cellToForgetWeights);
+ ignore_unused(cellToOutputWeights);
+ ignore_unused(reasonIfUnsupported);
+ return false;
}
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
@@ -215,125 +293,6 @@ bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
const MeanDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- return armnn::IsMeanSupportedRef(input, output, descriptor,reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
- const OriginsDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsMergerSupportedRef(inputs, descriptor, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsMultiplicationSupportedRef(input0, input1, output, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
- const TensorInfo& output,
- const NormalizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsNormalizationSupportedRef(input,
- output,
- descriptor,
- reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsOutputSupportedRef(output, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
- const TensorInfo& output,
- const PadDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsPadSupportedRef(input, output, descriptor, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
- const TensorInfo& output,
- const PermuteDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsPermuteSupportedRef(input, output, descriptor, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
- const TensorInfo& output,
- const Pooling2dDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsPooling2dSupportedRef(input, output, descriptor, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsReshapeSupportedRef(input, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsResizeBilinearSupportedRef(input, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
- const TensorInfo& output,
- const SoftmaxDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsSoftmaxSupportedRef(input, output, descriptor, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
- const ViewsDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsSplitterSupportedRef(input, descriptor, reasonIfUnsupported);
-}
-
-bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported) const
-{
- return armnn::IsSubtractionSupportedRef(input0, input1, output, reasonIfUnsupported);
-}
-
-//
-// Implementation functions
-//
-// TODO: Functions kept for backward compatibility. Remove once transition to plugable backends is complete!
-
-template<typename Float32Func, typename Uint8Func, typename ... Params>
-bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
- DataType dataType,
- Float32Func floatFuncPtr,
- Uint8Func uint8FuncPtr,
- Params&&... params)
-{
- return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
- dataType,
- &FalseFunc<Params...>,
- floatFuncPtr,
- uint8FuncPtr,
- std::forward<Params>(params)...);
-}
-
-bool IsActivationSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const ActivationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
-{
ignore_unused(output);
ignore_unused(descriptor);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
@@ -342,95 +301,21 @@ bool IsActivationSupportedRef(const TensorInfo& input,
&TrueFunc<>);
}
-bool IsAdditionSupportedRef(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
-}
-
-bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const TensorInfo& mean,
- const TensorInfo& var,
- const TensorInfo& beta,
- const TensorInfo& gamma,
- const BatchNormalizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
-}
-
-bool IsConstantSupportedRef(const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
-{
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- output.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
-}
-
-bool IsConvolution2dSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const Convolution2dDescriptor& descriptor,
- const TensorInfo& weights,
- const Optional<TensorInfo>& biases,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(descriptor);
- ignore_unused(output);
- ignore_unused(weights);
- ignore_unused(biases);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
-}
-
-bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const DepthwiseConvolution2dDescriptor& descriptor,
- const TensorInfo& weights,
- const Optional<TensorInfo>& biases,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
+ const OriginsDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(output);
ignore_unused(descriptor);
- ignore_unused(weights);
- ignore_unused(biases);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
-}
-
-bool IsDivisionSupportedRef(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(input1);
- ignore_unused(output);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
+ inputs[0]->GetDataType(),
&TrueFunc<>,
&TrueFunc<>);
}
-bool IsSubtractionSupportedRef(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(input1);
ignore_unused(output);
@@ -440,95 +325,59 @@ bool IsSubtractionSupportedRef(const TensorInfo& input0,
&TrueFunc<>);
}
-bool IsFullyConnectedSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const TensorInfo& weights,
- const TensorInfo& biases,
- const FullyConnectedDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const NormalizationDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(output);
ignore_unused(descriptor);
- ignore_unused(weights);
- ignore_unused(biases);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
&TrueFunc<>,
- &TrueFunc<>);
+ &FalseFuncU8<>);
}
-bool IsInputSupportedRef(const TensorInfo& input,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
{
return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
+ output.GetDataType(),
&TrueFunc<>,
&TrueFunc<>);
}
-bool IsL2NormalizationSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const L2NormalizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const PadDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
{
+ ignore_unused(input);
ignore_unused(output);
ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &FalseFuncU8<>);
-}
-
-bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
- const OriginsDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- inputs[0]->GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ ignore_unused(reasonIfUnsupported);
+ return false;
}
-bool IsMultiplicationSupportedRef(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const PermuteDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
-}
-
-bool IsNormalizationSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const NormalizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
-{
ignore_unused(descriptor);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
&TrueFunc<>,
- &FalseFuncU8<>);
-}
-
-bool IsOutputSupportedRef(const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
-{
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- output.GetDataType(),
- &TrueFunc<>,
&TrueFunc<>);
}
-bool IsPermuteSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const PermuteDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const Pooling2dDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
{
+ ignore_unused(output);
ignore_unused(descriptor);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
@@ -536,20 +385,17 @@ bool IsPermuteSupportedRef(const TensorInfo& input,
&TrueFunc<>);
}
-bool IsPooling2dSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const Pooling2dDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
+ Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(descriptor);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
&TrueFunc<>,
&TrueFunc<>);
}
-bool IsResizeBilinearSupportedRef(const TensorInfo& input,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
+ Optional<std::string&> reasonIfUnsupported) const
{
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
@@ -557,10 +403,10 @@ bool IsResizeBilinearSupportedRef(const TensorInfo& input,
&TrueFunc<>);
}
-bool IsSoftmaxSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const SoftmaxDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const SoftmaxDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(output);
ignore_unused(descriptor);
@@ -570,9 +416,9 @@ bool IsSoftmaxSupportedRef(const TensorInfo& input,
&TrueFunc<>);
}
-bool IsSplitterSupportedRef(const TensorInfo& input,
- const ViewsDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
+ const ViewsDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(descriptor);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
@@ -581,146 +427,17 @@ bool IsSplitterSupportedRef(const TensorInfo& input,
&TrueFunc<>);
}
-bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
- const FakeQuantizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &FalseFuncU8<>);
-}
-
-bool IsReshapeSupportedRef(const TensorInfo& input,
- Optional<std::string&> reasonIfUnsupported)
-{
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
-}
-
-bool IsFloorSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &FalseFuncU8<>);
-}
-
-bool IsLstmSupportedRef(const TensorInfo& input,
- const TensorInfo& outputStateIn,
- const TensorInfo& cellStateIn,
- const TensorInfo& scratchBuffer,
- const TensorInfo& outputStateOut,
- const TensorInfo& cellStateOut,
- const TensorInfo& output,
- const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(input);
- ignore_unused(outputStateIn);
- ignore_unused(cellStateIn);
- ignore_unused(scratchBuffer);
- ignore_unused(outputStateOut);
- ignore_unused(cellStateOut);
- ignore_unused(output);
- ignore_unused(descriptor);
- ignore_unused(inputToForgetWeights);
- ignore_unused(inputToCellWeights);
- ignore_unused(inputToOutputWeights);
- ignore_unused(recurrentToForgetWeights);
- ignore_unused(recurrentToCellWeights);
- ignore_unused(recurrentToOutputWeights);
- ignore_unused(forgetGateBias);
- ignore_unused(cellBias);
- ignore_unused(outputGateBias);
- ignore_unused(inputToInputWeights);
- ignore_unused(recurrentToInputWeights);
- ignore_unused(cellToInputWeights);
- ignore_unused(inputGateBias);
- ignore_unused(projectionWeights);
- ignore_unused(projectionBias);
- ignore_unused(cellToForgetWeights);
- ignore_unused(cellToOutputWeights);
- ignore_unused(reasonIfUnsupported);
- return false;
-}
-
-bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
-{
- return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &FalseInputFuncF32<>,
- &FalseFuncU8<>) &&
- IsSupportedForDataTypeGeneric(reasonIfUnsupported,
- output.GetDataType(),
- &FalseOutputFuncF16<>,
- &TrueFunc<>,
- &FalseFuncU8<>));
-}
-
-bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported)
-{
- return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
- input.GetDataType(),
- &FalseInputFuncF16<>,
- &TrueFunc<>,
- &FalseFuncU8<>) &&
- IsSupportedForDataTypeGeneric(reasonIfUnsupported,
- output.GetDataType(),
- &TrueFunc<>,
- &FalseOutputFuncF32<>,
- &FalseFuncU8<>));
-}
-
-bool IsMeanSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const MeanDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
+bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
{
+ ignore_unused(input1);
ignore_unused(output);
- ignore_unused(descriptor);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
+ input0.GetDataType(),
&TrueFunc<>,
&TrueFunc<>);
}
-bool IsPadSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const PadDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported)
-{
- ignore_unused(output);
- ignore_unused(descriptor);
- return false;
-}
-
} // namespace armnn
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 40bca7f179..0da59986a9 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -169,160 +169,4 @@ public:
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
};
-bool IsActivationSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const ActivationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsAdditionSupportedRef(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const TensorInfo& mean,
- const TensorInfo& var,
- const TensorInfo& beta,
- const TensorInfo& gamma,
- const BatchNormalizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsConstantSupportedRef(const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsConvolution2dSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const Convolution2dDescriptor& descriptor,
- const TensorInfo& weights,
- const Optional<TensorInfo>& biases,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const DepthwiseConvolution2dDescriptor& descriptor,
- const TensorInfo& weights,
- const Optional<TensorInfo>& biases,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsDivisionSupportedRef(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsSubtractionSupportedRef(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsFullyConnectedSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const TensorInfo& weights,
- const TensorInfo& biases,
- const FullyConnectedDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsInputSupportedRef(const TensorInfo& input,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsL2NormalizationSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const L2NormalizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsLstmSupportedRef(const TensorInfo& input,
- const TensorInfo& outputStateIn,
- const TensorInfo& cellStateIn,
- const TensorInfo& scratchBuffer,
- const TensorInfo& outputStateOut,
- const TensorInfo& cellStateOut,
- const TensorInfo& output,
- const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
- const OriginsDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsMultiplicationSupportedRef(const TensorInfo& input0,
- const TensorInfo& input1,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsNormalizationSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const NormalizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsOutputSupportedRef(const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsPermuteSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const PermuteDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsPooling2dSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const Pooling2dDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsResizeBilinearSupportedRef(const TensorInfo& input,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsSoftmaxSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const SoftmaxDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsSplitterSupportedRef(const TensorInfo& input,
- const ViewsDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
- const FakeQuantizationDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsReshapeSupportedRef(const TensorInfo& input,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsFloorSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsMeanSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const MeanDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsPadSupportedRef(const TensorInfo& input,
- const TensorInfo& output,
- const PadDescriptor& descriptor,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-}
+} // namespace armnn