diff options
Diffstat (limited to 'src/backends/cl/ClLayerSupport.cpp')
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 203 |
1 files changed, 123 insertions, 80 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 434b069092..494b339952 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -22,16 +22,16 @@ #include "workloads/ClConvolution2dWorkload.hpp" #include "workloads/ClDepthwiseConvolutionWorkload.hpp" #include "workloads/ClDivisionFloatWorkload.hpp" +#include "workloads/ClFullyConnectedWorkload.hpp" #include "workloads/ClL2NormalizationFloatWorkload.hpp" +#include "workloads/ClLstmFloatWorkload.hpp" #include "workloads/ClMultiplicationWorkload.hpp" -#include "workloads/ClFullyConnectedWorkload.hpp" +#include "workloads/ClNormalizationFloatWorkload.hpp" #include "workloads/ClPadWorkload.hpp" -#include "workloads/ClPooling2dBaseWorkload.hpp" #include "workloads/ClPermuteWorkload.hpp" -#include "workloads/ClNormalizationFloatWorkload.hpp" +#include "workloads/ClPooling2dBaseWorkload.hpp" #include "workloads/ClSoftmaxBaseWorkload.hpp" #include "workloads/ClSubtractionWorkload.hpp" -#include "workloads/ClLstmFloatWorkload.hpp" #endif using namespace boost; @@ -59,14 +59,14 @@ bool IsMatchingStride(uint32_t actualStride) return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride); }; -bool IsClBackendSupported(std::string* reasonIfUnsupported) +bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported) { #if ARMCOMPUTECL_ENABLED return true; #else - if (reasonIfUnsupported != nullptr) + if (reasonIfUnsupported) { - *reasonIfUnsupported = "The armnn library has been built without CL support"; + reasonIfUnsupported.value() = "The armnn library has been built without CL support"; } return false; #endif @@ -80,13 +80,13 @@ bool IsClBackendSupported(std::string* reasonIfUnsupported) #if ARMCOMPUTECL_ENABLED template<class FuncType, class... Args> -inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args) +inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args) { arm_compute::Status aclStatus = func(std::forward<Args>(args)...); const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK); if (!supported && reasonIfUnsupported) { - *reasonIfUnsupported = aclStatus.error_description(); + reasonIfUnsupported.value() = aclStatus.error_description(); } return supported; } @@ -101,7 +101,7 @@ inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupporte } //namespace template<typename FloatFunc, typename Uint8Func, typename ... Params> -bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported, +bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported, DataType dataType, FloatFunc floatFuncPtr, Uint8Func uint8FuncPtr, @@ -119,7 +119,7 @@ bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported, bool IsActivationSupportedCl(const TensorInfo& input, const TensorInfo& output, const ActivationDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate, reasonIfUnsupported, @@ -131,12 +131,13 @@ bool IsActivationSupportedCl(const TensorInfo& input, bool IsAdditionSupportedCl(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { - return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0, - input1, - output, - reasonIfUnsupported)); + FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate, + reasonIfUnsupported, + input0, + input1, + output); } bool IsBatchNormalizationSupportedCl(const TensorInfo& input, @@ -146,7 +147,7 @@ bool IsBatchNormalizationSupportedCl(const TensorInfo& input, const TensorInfo& beta, const TensorInfo& gamma, const BatchNormalizationDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate, reasonIfUnsupported, @@ -160,7 +161,7 @@ bool IsBatchNormalizationSupportedCl(const TensorInfo& input, } bool IsConstantSupportedCl(const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { return IsSupportedForDataTypeCl(reasonIfUnsupported, output.GetDataType(), @@ -201,10 +202,11 @@ bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convol return isSupported; } -bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported, +bool IsDirectConvolution2dParamsSupportedCl(Optional<std::string&> reasonIfUnsupported, const Convolution2dDescriptor& parameters, const TensorInfo& weightInfo) { + ignore_unused(reasonIfUnsupported); return IsClDirectConvolution2dSupported(weightInfo, parameters); } @@ -213,7 +215,7 @@ bool IsConvolution2dSupportedCl(const TensorInfo& input, const Convolution2dDescriptor& descriptor, const TensorInfo& weights, const Optional<TensorInfo>& biases, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate, reasonIfUnsupported, @@ -229,7 +231,7 @@ bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input, const DepthwiseConvolution2dDescriptor& descriptor, const TensorInfo& weights, const Optional<TensorInfo>& biases, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate, reasonIfUnsupported, @@ -243,7 +245,7 @@ bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input, bool IsDivisionSupportedCl(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate, reasonIfUnsupported, @@ -255,12 +257,14 @@ bool IsDivisionSupportedCl(const TensorInfo& input0, bool IsSubtractionSupportedCl(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { - return FORWARD_CL_LAYER_SUPPORT_FUNC(ClSubtractionValidate(input0, - input1, - output, - reasonIfUnsupported)); + + FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate, + reasonIfUnsupported, + input0, + input1, + output); } bool IsFullyConnectedSupportedCl(const TensorInfo& input, @@ -268,7 +272,7 @@ bool IsFullyConnectedSupportedCl(const TensorInfo& input, const TensorInfo& weights, const TensorInfo& biases, const FullyConnectedDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate, reasonIfUnsupported, @@ -280,7 +284,7 @@ bool IsFullyConnectedSupportedCl(const TensorInfo& input, } bool IsInputSupportedCl(const TensorInfo& input, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { return IsSupportedForDataTypeCl(reasonIfUnsupported, input.GetDataType(), @@ -291,14 +295,14 @@ bool IsInputSupportedCl(const TensorInfo& input, bool IsL2NormalizationSupportedCl(const TensorInfo& input, const TensorInfo& output, const L2NormalizationDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor); } bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs, const OriginsDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeCl(reasonIfUnsupported, @@ -310,7 +314,7 @@ bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs, bool IsMultiplicationSupportedCl(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate, reasonIfUnsupported, @@ -322,13 +326,13 @@ bool IsMultiplicationSupportedCl(const TensorInfo& input0, bool IsNormalizationSupportedCl(const TensorInfo& input, const TensorInfo& output, const NormalizationDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor); } bool IsOutputSupportedCl(const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { return IsSupportedForDataTypeCl(reasonIfUnsupported, output.GetDataType(), @@ -336,18 +340,10 @@ bool IsOutputSupportedCl(const TensorInfo& output, &TrueFunc<>); } -bool IsPadSupportedCl(const TensorInfo& input, - const TensorInfo& output, - const PadDescriptor& descriptor, - std::string* reasonIfUnsupported) -{ - return FORWARD_CL_LAYER_SUPPORT_FUNC(ClPadValidate(input, output, descriptor, reasonIfUnsupported)); -} - bool IsPermuteSupportedCl(const TensorInfo& input, const TensorInfo& output, const PermuteDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { ignore_unused(input); ignore_unused(output); @@ -357,13 +353,13 @@ bool IsPermuteSupportedCl(const TensorInfo& input, bool IsPooling2dSupportedCl(const TensorInfo& input, const TensorInfo& output, const Pooling2dDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor); } bool IsResizeBilinearSupportedCl(const TensorInfo& input, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { return IsSupportedForDataTypeCl(reasonIfUnsupported, input.GetDataType(), @@ -374,7 +370,7 @@ bool IsResizeBilinearSupportedCl(const TensorInfo& input, bool IsSoftmaxSupportedCl(const TensorInfo& input, const TensorInfo& output, const SoftmaxDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { ignore_unused(descriptor); FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output); @@ -382,7 +378,7 @@ bool IsSoftmaxSupportedCl(const TensorInfo& input, bool IsSplitterSupportedCl(const TensorInfo& input, const ViewsDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeCl(reasonIfUnsupported, @@ -393,23 +389,25 @@ bool IsSplitterSupportedCl(const TensorInfo& input, bool IsFakeQuantizationSupportedCl(const TensorInfo& input, const FakeQuantizationDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { ignore_unused(input); ignore_unused(descriptor); + ignore_unused(reasonIfUnsupported); return false; } bool IsReshapeSupportedCl(const TensorInfo& input, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { ignore_unused(input); + ignore_unused(reasonIfUnsupported); return true; } bool IsFloorSupportedCl(const TensorInfo& input, const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { ignore_unused(output); return IsClBackendSupported(reasonIfUnsupported) && @@ -420,59 +418,104 @@ bool IsFloorSupportedCl(const TensorInfo& input, &FalseFuncU8<>); } -bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn, - const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, - const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, - const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported) -{ - FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate, reasonIfUnsupported, - input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut, - output, descriptor, inputToForgetWeights, inputToCellWeights, - inputToOutputWeights, recurrentToForgetWeights, - recurrentToCellWeights, recurrentToOutputWeights, - forgetGateBias, cellBias, outputGateBias, - inputToInputWeights, recurrentToInputWeights, - cellToInputWeights, inputGateBias, projectionWeights, - projectionBias, cellToForgetWeights, cellToOutputWeights); +bool IsLstmSupportedCl(const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& scratchBuffer, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const LstmDescriptor& descriptor, + const TensorInfo& inputToForgetWeights, + const TensorInfo& inputToCellWeights, + const TensorInfo& inputToOutputWeights, + const TensorInfo& recurrentToForgetWeights, + const TensorInfo& recurrentToCellWeights, + const TensorInfo& recurrentToOutputWeights, + const TensorInfo& forgetGateBias, + const TensorInfo& cellBias, + const TensorInfo& outputGateBias, + const TensorInfo* inputToInputWeights, + const TensorInfo* recurrentToInputWeights, + const TensorInfo* cellToInputWeights, + const TensorInfo* inputGateBias, + const TensorInfo* projectionWeights, + const TensorInfo* projectionBias, + const TensorInfo* cellToForgetWeights, + const TensorInfo* cellToOutputWeights, + Optional<std::string&> reasonIfUnsupported) +{ + FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate, + reasonIfUnsupported, + input, + outputStateIn, + cellStateIn, + scratchBuffer, + outputStateOut, + cellStateOut, + output, + descriptor, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + forgetGateBias, + cellBias, + outputGateBias, + inputToInputWeights, + recurrentToInputWeights, + cellToInputWeights, + inputGateBias, + projectionWeights, + projectionBias, + cellToForgetWeights, + cellToOutputWeights); } bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input, const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate, reasonIfUnsupported, input, - output, - reasonIfUnsupported); + output); } bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input, const TensorInfo& output, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate, reasonIfUnsupported, input, - output, - reasonIfUnsupported); + output); } bool IsMeanSupportedCl(const TensorInfo& input, const TensorInfo& output, const MeanDescriptor& descriptor, - std::string* reasonIfUnsupported) + Optional<std::string&> reasonIfUnsupported) { + ignore_unused(input); + ignore_unused(output); + ignore_unused(descriptor); + ignore_unused(reasonIfUnsupported); return false; } +bool IsPadSupportedCl(const TensorInfo& input, + const TensorInfo& output, + const PadDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) +{ + FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate, + reasonIfUnsupported, + input, + output, + descriptor); +} + } |