diff options
Diffstat (limited to 'src/armnn/backends/ClLayerSupport.hpp')
-rw-r--r-- | src/armnn/backends/ClLayerSupport.hpp | 39 |
1 files changed, 38 insertions, 1 deletions
diff --git a/src/armnn/backends/ClLayerSupport.hpp b/src/armnn/backends/ClLayerSupport.hpp index 4f71e907cf..791e904616 100644 --- a/src/armnn/backends/ClLayerSupport.hpp +++ b/src/armnn/backends/ClLayerSupport.hpp @@ -7,16 +7,17 @@ #include <armnn/DescriptorsFwd.hpp> #include <armnn/Types.hpp> #include <armnn/Tensor.hpp> +#include <armnn/ArmNN.hpp> namespace armnn { bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc); -bool IsClActivationUint8Supported(std::string* reasonIfUnsupported, const ActivationDescriptor& parameters); bool IsClDepthwiseConvolution2dDescParamsSupported(std::string* reasonIfUnsupported, const DepthwiseConvolution2dDescriptor& parameters, const TensorInfo& weights); bool IsActivationSupportedCl(const TensorInfo& input, + const TensorInfo& output, const ActivationDescriptor& descriptor, std::string* reasonIfUnsupported = nullptr); @@ -26,6 +27,11 @@ bool IsAdditionSupportedCl(const TensorInfo& input0, std::string* reasonIfUnsupported = nullptr); bool IsBatchNormalizationSupportedCl(const TensorInfo& input, + const TensorInfo& output, + const TensorInfo& mean, + const TensorInfo& var, + const TensorInfo& beta, + const TensorInfo& gamma, const BatchNormalizationDescriptor& descriptor, std::string* reasonIfUnsupported = nullptr); @@ -40,11 +46,16 @@ bool IsConvolution2dSupportedCl(const TensorInfo& input, std::string* reasonIfUnsupported = nullptr); bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input, + const TensorInfo& output, const DepthwiseConvolution2dDescriptor& descriptor, const TensorInfo& weights, + const TensorInfo& biases, std::string* reasonIfUnsupported = nullptr); bool IsFullyConnectedSupportedCl(const TensorInfo& input, + const TensorInfo& output, + const TensorInfo& weights, + const TensorInfo& biases, const FullyConnectedDescriptor& descriptor, std::string* reasonIfUnsupported = nullptr); @@ -52,14 +63,30 @@ bool IsInputSupportedCl(const TensorInfo& input, std::string* reasonIfUnsupported = nullptr); bool IsL2NormalizationSupportedCl(const TensorInfo& input, + const TensorInfo& output, std::string* reasonIfUnsupported = nullptr); +bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, + const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, + const TensorInfo& output, const LstmDescriptor& descriptor, + const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights, + const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights, + const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights, + const TensorInfo& forgetGateBias, const TensorInfo& cellBias, + const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights, + const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights, + const TensorInfo* inputGateBias, const TensorInfo* projectionWeights, + const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights, + const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported = nullptr); + bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs, const OriginsDescriptor& descriptor, std::string* reasonIfUnsupported = nullptr); bool IsMultiplicationSupportedCl(const TensorInfo& input0, const TensorInfo& input1, + const TensorInfo& output, std::string* reasonIfUnsupported = nullptr); bool IsNormalizationSupportedCl(const TensorInfo& input, @@ -84,6 +111,7 @@ bool IsResizeBilinearSupportedCl(const TensorInfo& input, std::string* reasonIfUnsupported = nullptr); bool IsSoftmaxSupportedCl(const TensorInfo& input, + const TensorInfo& output, const SoftmaxDescriptor& descriptor, std::string* reasonIfUnsupported = nullptr); @@ -101,4 +129,13 @@ bool IsReshapeSupportedCl(const TensorInfo& input, bool IsFloorSupportedCl(const TensorInfo& input, const TensorInfo& output, std::string* reasonIfUnsupported = nullptr); + +bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input, + const TensorInfo& output, + std::string* reasonIfUnsupported = nullptr); + +bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input, + const TensorInfo& output, + std::string* reasonIfUnsupported = nullptr); + } |