From b5acbb77918df98debac200ebe082ce9aaab6a8c Mon Sep 17 00:00:00 2001
From: Aron Virginas-Tar
Date: Mon, 15 Oct 2018 11:11:51 +0100
Subject: IVGCVSW-2004: Get rid of IsLayerSupportedRef functions in favor of
 ILayerSupport interface

Change-Id: Ia147a0b408b2ca951c214963432d6e0f9b27b973
---
 src/backends/reference/RefLayerSupport.cpp | 671 +++++++++--------------------
 src/backends/reference/RefLayerSupport.hpp | 158 +------
 2 files changed, 195 insertions(+), 634 deletions(-)

diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 3a250a6981..909df75445 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -5,26 +5,48 @@
 #include "RefLayerSupport.hpp"
 
-#include 
-
-#include 
+#include 
+#include 
 
 #include 
-#include 
 
 #include 
-#include "InternalTypes.hpp"
 
 using namespace boost;
 
 namespace armnn
 {
 
+namespace
+{
+
+template<typename Float32Func, typename Uint8Func, typename ... Params>
+bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
+                               DataType dataType,
+                               Float32Func floatFuncPtr,
+                               Uint8Func uint8FuncPtr,
+                               Params&&... params)
+{
+    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                         dataType,
+                                         &FalseFunc<Params...>,
+                                         floatFuncPtr,
+                                         uint8FuncPtr,
+                                         std::forward<Params>(params)...);
+}
+
+} // anonymous namespace
+
 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                             const TensorInfo& output,
                                             const ActivationDescriptor& descriptor,
                                             Optional<std::string&> reasonIfUnsupported) const
 {
-    return armnn::IsActivationSupportedRef(input, output, descriptor, reasonIfUnsupported);
+    ignore_unused(output);
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
 }
 
 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
@@ -32,7 +54,12 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                           const TensorInfo& output,
                                           Optional<std::string&> reasonIfUnsupported) const
 {
-    return armnn::IsAdditionSupportedRef(input0, input1, output, reasonIfUnsupported);
+    ignore_unused(input1);
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input0.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
 }
 
 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
@@ -44,34 +71,57 @@ bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                     const BatchNormalizationDescriptor& descriptor,
                                                     Optional<std::string&> reasonIfUnsupported) const
 {
-    return armnn::IsBatchNormalizationSupportedRef(input,
-                                                   output,
-                                                   mean,
-                                                   var,
-                                                   beta,
-                                                   gamma,
-                                                   descriptor,
-                                                   reasonIfUnsupported);
+    ignore_unused(output);
+    ignore_unused(mean);
+    ignore_unused(var);
+    ignore_unused(beta);
+    ignore_unused(gamma);
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
 }
 
 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                           Optional<std::string&> reasonIfUnsupported) const
 {
-    return armnn::IsConstantSupportedRef(output, reasonIfUnsupported);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     output.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
 }
 
 bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    Optional<std::string&> reasonIfUnsupported) const
 {
-    return armnn::IsConvertFp16ToFp32SupportedRef(input, output, reasonIfUnsupported);
+    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                          input.GetDataType(),
+                                          &TrueFunc<>,
+                                          &FalseInputFuncF32<>,
+                                          &FalseFuncU8<>) &&
+            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                          output.GetDataType(),
+ &FalseOutputFuncF16<>, + &TrueFunc<>, + &FalseFuncU8<>)); } bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) const { - return armnn::IsConvertFp32ToFp16SupportedRef(input, output, reasonIfUnsupported); + return (IsSupportedForDataTypeGeneric(reasonIfUnsupported, + input.GetDataType(), + &FalseInputFuncF16<>, + &TrueFunc<>, + &FalseFuncU8<>) && + IsSupportedForDataTypeGeneric(reasonIfUnsupported, + output.GetDataType(), + &TrueFunc<>, + &FalseOutputFuncF32<>, + &FalseFuncU8<>)); } bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, @@ -81,12 +131,14 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, const Optional& biases, Optional reasonIfUnsupported) const { - return armnn::IsConvolution2dSupportedRef(input, - output, - descriptor, - weights, - biases, - reasonIfUnsupported); + ignore_unused(output); + ignore_unused(descriptor); + ignore_unused(weights); + ignore_unused(biases); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &TrueFunc<>); } bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input, @@ -96,12 +148,14 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input, const Optional& biases, Optional reasonIfUnsupported) const { - return armnn::IsDepthwiseConvolutionSupportedRef(input, - output, - descriptor, - weights, - biases, - reasonIfUnsupported); + ignore_unused(output); + ignore_unused(descriptor); + ignore_unused(weights); + ignore_unused(biases); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &TrueFunc<>); } bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0, @@ -109,21 +163,34 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0, const TensorInfo& output, Optional reasonIfUnsupported) const { - return armnn::IsDivisionSupportedRef(input0, input1, output, reasonIfUnsupported); + ignore_unused(input1); + ignore_unused(output); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input0.GetDataType(), + &TrueFunc<>, + &TrueFunc<>); } bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input, const FakeQuantizationDescriptor& descriptor, Optional reasonIfUnsupported) const { - return armnn::IsFakeQuantizationSupportedRef(input, descriptor, reasonIfUnsupported); + ignore_unused(descriptor); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &FalseFuncU8<>); } bool RefLayerSupport::IsFloorSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) const { - return armnn::IsFloorSupportedRef(input, output, reasonIfUnsupported); + ignore_unused(output); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &FalseFuncU8<>); } bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, @@ -133,18 +200,23 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, const FullyConnectedDescriptor& descriptor, Optional reasonIfUnsupported) const { - return armnn::IsFullyConnectedSupportedRef(input, - output, - weights, - biases, - descriptor, - reasonIfUnsupported); + ignore_unused(output); + ignore_unused(weights); + ignore_unused(biases); + ignore_unused(descriptor); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &TrueFunc<>); } bool RefLayerSupport::IsInputSupported(const 
TensorInfo& input, Optional reasonIfUnsupported) const { - return armnn::IsInputSupportedRef(input, reasonIfUnsupported); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &TrueFunc<>); } bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input, @@ -152,7 +224,12 @@ bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input, const L2NormalizationDescriptor& descriptor, Optional reasonIfUnsupported) const { - return armnn::IsL2NormalizationSupportedRef(input, output, descriptor, reasonIfUnsupported); + ignore_unused(output); + ignore_unused(descriptor); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &FalseFuncU8<>); } bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, @@ -182,157 +259,39 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo* cellToOutputWeights, Optional reasonIfUnsupported) const { - return armnn::IsLstmSupportedRef(input, - outputStateIn, - cellStateIn, - scratchBuffer, - outputStateOut, - cellStateOut, - output, - descriptor, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights, - reasonIfUnsupported); + ignore_unused(input); + ignore_unused(outputStateIn); + ignore_unused(cellStateIn); + ignore_unused(scratchBuffer); + ignore_unused(outputStateOut); + ignore_unused(cellStateOut); + ignore_unused(output); + ignore_unused(descriptor); + ignore_unused(inputToForgetWeights); + ignore_unused(inputToCellWeights); + ignore_unused(inputToOutputWeights); + ignore_unused(recurrentToForgetWeights); + ignore_unused(recurrentToCellWeights); + ignore_unused(recurrentToOutputWeights); + ignore_unused(forgetGateBias); + ignore_unused(cellBias); + ignore_unused(outputGateBias); + ignore_unused(inputToInputWeights); + ignore_unused(recurrentToInputWeights); + ignore_unused(cellToInputWeights); + ignore_unused(inputGateBias); + ignore_unused(projectionWeights); + ignore_unused(projectionBias); + ignore_unused(cellToForgetWeights); + ignore_unused(cellToOutputWeights); + ignore_unused(reasonIfUnsupported); + return false; } bool RefLayerSupport::IsMeanSupported(const TensorInfo& input, const TensorInfo& output, const MeanDescriptor& descriptor, Optional reasonIfUnsupported) const -{ - return armnn::IsMeanSupportedRef(input, output, descriptor,reasonIfUnsupported); -} - -bool RefLayerSupport::IsMergerSupported(const std::vector inputs, - const OriginsDescriptor& descriptor, - Optional reasonIfUnsupported) const -{ - return armnn::IsMergerSupportedRef(inputs, descriptor, reasonIfUnsupported); -} - -bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported) const -{ - return armnn::IsMultiplicationSupportedRef(input0, input1, output, reasonIfUnsupported); -} - -bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input, - const TensorInfo& output, - const NormalizationDescriptor& descriptor, - Optional reasonIfUnsupported) const -{ - return armnn::IsNormalizationSupportedRef(input, - output, - descriptor, - reasonIfUnsupported); -} - -bool RefLayerSupport::IsOutputSupported(const TensorInfo& output, - 
Optional reasonIfUnsupported) const -{ - return armnn::IsOutputSupportedRef(output, reasonIfUnsupported); -} - -bool RefLayerSupport::IsPadSupported(const TensorInfo& input, - const TensorInfo& output, - const PadDescriptor& descriptor, - Optional reasonIfUnsupported) const -{ - return armnn::IsPadSupportedRef(input, output, descriptor, reasonIfUnsupported); -} - -bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input, - const TensorInfo& output, - const PermuteDescriptor& descriptor, - Optional reasonIfUnsupported) const -{ - return armnn::IsPermuteSupportedRef(input, output, descriptor, reasonIfUnsupported); -} - -bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input, - const TensorInfo& output, - const Pooling2dDescriptor& descriptor, - Optional reasonIfUnsupported) const -{ - return armnn::IsPooling2dSupportedRef(input, output, descriptor, reasonIfUnsupported); -} - -bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input, - Optional reasonIfUnsupported) const -{ - return armnn::IsReshapeSupportedRef(input, reasonIfUnsupported); -} - -bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input, - Optional reasonIfUnsupported) const -{ - return armnn::IsResizeBilinearSupportedRef(input, reasonIfUnsupported); -} - -bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input, - const TensorInfo& output, - const SoftmaxDescriptor& descriptor, - Optional reasonIfUnsupported) const -{ - return armnn::IsSoftmaxSupportedRef(input, output, descriptor, reasonIfUnsupported); -} - -bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input, - const ViewsDescriptor& descriptor, - Optional reasonIfUnsupported) const -{ - return armnn::IsSplitterSupportedRef(input, descriptor, reasonIfUnsupported); -} - -bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported) const -{ - return armnn::IsSubtractionSupportedRef(input0, input1, output, reasonIfUnsupported); -} - -// -// Implementation functions -// -// TODO: Functions kept for backward compatibility. Remove once transition to plugable backends is complete! - -template -bool IsSupportedForDataTypeRef(Optional reasonIfUnsupported, - DataType dataType, - Float32Func floatFuncPtr, - Uint8Func uint8FuncPtr, - Params&&... 
params) -{ - return IsSupportedForDataTypeGeneric(reasonIfUnsupported, - dataType, - &FalseFunc, - floatFuncPtr, - uint8FuncPtr, - std::forward(params)...); -} - -bool IsActivationSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const ActivationDescriptor& descriptor, - Optional reasonIfUnsupported) { ignore_unused(output); ignore_unused(descriptor); @@ -342,95 +301,21 @@ bool IsActivationSupportedRef(const TensorInfo& input, &TrueFunc<>); } -bool IsAdditionSupportedRef(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported) -{ - ignore_unused(input1); - ignore_unused(output); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input0.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); -} - -bool IsBatchNormalizationSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const TensorInfo& mean, - const TensorInfo& var, - const TensorInfo& beta, - const TensorInfo& gamma, - const BatchNormalizationDescriptor& descriptor, - Optional reasonIfUnsupported) -{ - ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); -} - -bool IsConstantSupportedRef(const TensorInfo& output, - Optional reasonIfUnsupported) -{ - return IsSupportedForDataTypeRef(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); -} - -bool IsConvolution2dSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const Convolution2dDescriptor& descriptor, - const TensorInfo& weights, - const Optional& biases, - Optional reasonIfUnsupported) -{ - ignore_unused(descriptor); - ignore_unused(output); - ignore_unused(weights); - ignore_unused(biases); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); -} - -bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const DepthwiseConvolution2dDescriptor& descriptor, - const TensorInfo& weights, - const Optional& biases, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsMergerSupported(const std::vector inputs, + const OriginsDescriptor& descriptor, + Optional reasonIfUnsupported) const { - ignore_unused(output); ignore_unused(descriptor); - ignore_unused(weights); - ignore_unused(biases); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); -} - -bool IsDivisionSupportedRef(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported) -{ - ignore_unused(input1); - ignore_unused(output); return IsSupportedForDataTypeRef(reasonIfUnsupported, - input0.GetDataType(), + inputs[0]->GetDataType(), &TrueFunc<>, &TrueFunc<>); } -bool IsSubtractionSupportedRef(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional reasonIfUnsupported) const { ignore_unused(input1); ignore_unused(output); @@ -440,95 +325,59 @@ bool IsSubtractionSupportedRef(const TensorInfo& input0, &TrueFunc<>); } -bool IsFullyConnectedSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const TensorInfo& weights, - const TensorInfo& biases, - const FullyConnectedDescriptor& descriptor, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input, + const TensorInfo& 
output, + const NormalizationDescriptor& descriptor, + Optional reasonIfUnsupported) const { ignore_unused(output); ignore_unused(descriptor); - ignore_unused(weights); - ignore_unused(biases); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, - &TrueFunc<>); + &FalseFuncU8<>); } -bool IsInputSupportedRef(const TensorInfo& input, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsOutputSupported(const TensorInfo& output, + Optional reasonIfUnsupported) const { return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), + output.GetDataType(), &TrueFunc<>, &TrueFunc<>); } -bool IsL2NormalizationSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const L2NormalizationDescriptor& descriptor, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsPadSupported(const TensorInfo& input, + const TensorInfo& output, + const PadDescriptor& descriptor, + Optional reasonIfUnsupported) const { + ignore_unused(input); ignore_unused(output); ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &FalseFuncU8<>); -} - -bool IsMergerSupportedRef(const std::vector inputs, - const OriginsDescriptor& descriptor, - Optional reasonIfUnsupported) -{ - ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - inputs[0]->GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + ignore_unused(reasonIfUnsupported); + return false; } -bool IsMultiplicationSupportedRef(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input, + const TensorInfo& output, + const PermuteDescriptor& descriptor, + Optional reasonIfUnsupported) const { - ignore_unused(input1); ignore_unused(output); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input0.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); -} - -bool IsNormalizationSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const NormalizationDescriptor& descriptor, - Optional reasonIfUnsupported) -{ ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, - &FalseFuncU8<>); -} - -bool IsOutputSupportedRef(const TensorInfo& output, - Optional reasonIfUnsupported) -{ - return IsSupportedForDataTypeRef(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, &TrueFunc<>); } -bool IsPermuteSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const PermuteDescriptor& descriptor, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input, + const TensorInfo& output, + const Pooling2dDescriptor& descriptor, + Optional reasonIfUnsupported) const { + ignore_unused(output); ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), @@ -536,20 +385,17 @@ bool IsPermuteSupportedRef(const TensorInfo& input, &TrueFunc<>); } -bool IsPooling2dSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const Pooling2dDescriptor& descriptor, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input, + Optional reasonIfUnsupported) const { - ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } -bool IsResizeBilinearSupportedRef(const TensorInfo& input, - Optional reasonIfUnsupported) +bool 
RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input, + Optional reasonIfUnsupported) const { return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), @@ -557,10 +403,10 @@ bool IsResizeBilinearSupportedRef(const TensorInfo& input, &TrueFunc<>); } -bool IsSoftmaxSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const SoftmaxDescriptor& descriptor, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const SoftmaxDescriptor& descriptor, + Optional reasonIfUnsupported) const { ignore_unused(output); ignore_unused(descriptor); @@ -570,9 +416,9 @@ bool IsSoftmaxSupportedRef(const TensorInfo& input, &TrueFunc<>); } -bool IsSplitterSupportedRef(const TensorInfo& input, - const ViewsDescriptor& descriptor, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input, + const ViewsDescriptor& descriptor, + Optional reasonIfUnsupported) const { ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, @@ -581,146 +427,17 @@ bool IsSplitterSupportedRef(const TensorInfo& input, &TrueFunc<>); } -bool IsFakeQuantizationSupportedRef(const TensorInfo& input, - const FakeQuantizationDescriptor& descriptor, - Optional reasonIfUnsupported) -{ - ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &FalseFuncU8<>); -} - -bool IsReshapeSupportedRef(const TensorInfo& input, - Optional reasonIfUnsupported) -{ - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); -} - -bool IsFloorSupportedRef(const TensorInfo& input, - const TensorInfo& output, - Optional reasonIfUnsupported) -{ - ignore_unused(output); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &FalseFuncU8<>); -} - -bool IsLstmSupportedRef(const TensorInfo& input, - const TensorInfo& outputStateIn, - const TensorInfo& cellStateIn, - const TensorInfo& scratchBuffer, - const TensorInfo& outputStateOut, - const TensorInfo& cellStateOut, - const TensorInfo& output, - const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported) -{ - ignore_unused(input); - ignore_unused(outputStateIn); - ignore_unused(cellStateIn); - ignore_unused(scratchBuffer); - ignore_unused(outputStateOut); - ignore_unused(cellStateOut); - ignore_unused(output); - ignore_unused(descriptor); - ignore_unused(inputToForgetWeights); - ignore_unused(inputToCellWeights); - ignore_unused(inputToOutputWeights); - ignore_unused(recurrentToForgetWeights); - ignore_unused(recurrentToCellWeights); - ignore_unused(recurrentToOutputWeights); - ignore_unused(forgetGateBias); - ignore_unused(cellBias); - ignore_unused(outputGateBias); - ignore_unused(inputToInputWeights); - 
ignore_unused(recurrentToInputWeights); - ignore_unused(cellToInputWeights); - ignore_unused(inputGateBias); - ignore_unused(projectionWeights); - ignore_unused(projectionBias); - ignore_unused(cellToForgetWeights); - ignore_unused(cellToOutputWeights); - ignore_unused(reasonIfUnsupported); - return false; -} - -bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input, - const TensorInfo& output, - Optional reasonIfUnsupported) -{ - return (IsSupportedForDataTypeGeneric(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &FalseInputFuncF32<>, - &FalseFuncU8<>) && - IsSupportedForDataTypeGeneric(reasonIfUnsupported, - output.GetDataType(), - &FalseOutputFuncF16<>, - &TrueFunc<>, - &FalseFuncU8<>)); -} - -bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input, - const TensorInfo& output, - Optional reasonIfUnsupported) -{ - return (IsSupportedForDataTypeGeneric(reasonIfUnsupported, - input.GetDataType(), - &FalseInputFuncF16<>, - &TrueFunc<>, - &FalseFuncU8<>) && - IsSupportedForDataTypeGeneric(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &FalseOutputFuncF32<>, - &FalseFuncU8<>)); -} - -bool IsMeanSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const MeanDescriptor& descriptor, - Optional reasonIfUnsupported) +bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional reasonIfUnsupported) const { + ignore_unused(input1); ignore_unused(output); - ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), + input0.GetDataType(), &TrueFunc<>, &TrueFunc<>); } -bool IsPadSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const PadDescriptor& descriptor, - Optional reasonIfUnsupported) -{ - ignore_unused(output); - ignore_unused(descriptor); - return false; -} - } // namespace armnn diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 40bca7f179..0da59986a9 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -169,160 +169,4 @@ public: Optional reasonIfUnsupported = EmptyOptional()) const override; }; -bool IsActivationSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const ActivationDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsAdditionSupportedRef(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsBatchNormalizationSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const TensorInfo& mean, - const TensorInfo& var, - const TensorInfo& beta, - const TensorInfo& gamma, - const BatchNormalizationDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsConstantSupportedRef(const TensorInfo& output, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsConvolution2dSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const Convolution2dDescriptor& descriptor, - const TensorInfo& weights, - const Optional& biases, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const DepthwiseConvolution2dDescriptor& descriptor, - const TensorInfo& weights, - const Optional& biases, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsDivisionSupportedRef(const TensorInfo& input0, - const TensorInfo& 
input1, - const TensorInfo& output, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsSubtractionSupportedRef(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsFullyConnectedSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const TensorInfo& weights, - const TensorInfo& biases, - const FullyConnectedDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsInputSupportedRef(const TensorInfo& input, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsL2NormalizationSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const L2NormalizationDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsLstmSupportedRef(const TensorInfo& input, - const TensorInfo& outputStateIn, - const TensorInfo& cellStateIn, - const TensorInfo& scratchBuffer, - const TensorInfo& outputStateOut, - const TensorInfo& cellStateOut, - const TensorInfo& output, - const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsMergerSupportedRef(const std::vector inputs, - const OriginsDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsMultiplicationSupportedRef(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsNormalizationSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const NormalizationDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsOutputSupportedRef(const TensorInfo& output, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsPermuteSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const PermuteDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsPooling2dSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const Pooling2dDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsResizeBilinearSupportedRef(const TensorInfo& input, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsSoftmaxSupportedRef(const TensorInfo& input, - const TensorInfo& output, - const SoftmaxDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsSplitterSupportedRef(const TensorInfo& input, - const ViewsDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsFakeQuantizationSupportedRef(const TensorInfo& input, - const FakeQuantizationDescriptor& descriptor, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsReshapeSupportedRef(const TensorInfo& input, - Optional reasonIfUnsupported = EmptyOptional()); - -bool IsFloorSupportedRef(const TensorInfo& input, - const TensorInfo& output, - 
                         Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
-                                     const TensorInfo& output,
-                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
-                                     const TensorInfo& output,
-                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsMeanSupportedRef(const TensorInfo& input,
-                        const TensorInfo& output,
-                        const MeanDescriptor& descriptor,
-                        Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-bool IsPadSupportedRef(const TensorInfo& input,
-                       const TensorInfo& output,
-                       const PadDescriptor& descriptor,
-                       Optional<std::string&> reasonIfUnsupported = EmptyOptional());
-
-}
+} // namespace armnn

-- 
cgit v1.2.1
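
For context, a minimal sketch of how a caller exercises the reference backend's layer-support queries once the free IsXxxSupportedRef functions are removed: the checks are reached through the RefLayerSupport member functions of the ILayerSupport interface instead. Only the member signatures come from the diff above; the include paths, the TensorInfo/TensorShape construction, and the main() harness are illustrative assumptions, not part of this patch.

// Illustrative usage sketch (not part of the patch). Include paths and tensor
// construction are assumptions; only the RefLayerSupport member signatures are
// taken from the diff above.
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <armnn/Optional.hpp>

#include "RefLayerSupport.hpp" // internal header under src/backends/reference

#include <iostream>
#include <string>

int main()
{
    using namespace armnn;

    // Two Float32 operands and an output of the same shape.
    const TensorInfo input0(TensorShape({1, 2, 2, 3}), DataType::Float32);
    const TensorInfo input1(TensorShape({1, 2, 2, 3}), DataType::Float32);
    const TensorInfo output(TensorShape({1, 2, 2, 3}), DataType::Float32);

    // Before this patch the query was a free function, e.g.
    //   armnn::IsAdditionSupportedRef(input0, input1, output, reason);
    // After it, the reference backend is queried through the ILayerSupport
    // member functions implemented by RefLayerSupport.
    RefLayerSupport layerSupport;

    std::string reason;
    const bool supported = layerSupport.IsAdditionSupported(input0, input1, output,
                                                            Optional<std::string&>(reason));

    std::cout << "Addition supported on CpuRef: " << (supported ? "yes" : "no");
    if (!supported)
    {
        std::cout << " (" << reason << ")";
    }
    std::cout << std::endl;
    return 0;
}

Per the IsAdditionSupported implementation in the diff, Float32 inputs take the TrueFunc path, so this query reports the layer as supported and the reason string stays empty; reason is only written on the unsupported paths.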