aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/RefLayerSupport.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/RefLayerSupport.hpp')
-rw-r--r--src/armnn/backends/RefLayerSupport.hpp38
1 files changed, 38 insertions, 0 deletions
diff --git a/src/armnn/backends/RefLayerSupport.hpp b/src/armnn/backends/RefLayerSupport.hpp
index 9db1c14596..5e543ac537 100644
--- a/src/armnn/backends/RefLayerSupport.hpp
+++ b/src/armnn/backends/RefLayerSupport.hpp
@@ -7,11 +7,14 @@
#include <armnn/DescriptorsFwd.hpp>
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>
+#include <layers/LstmLayer.hpp>
+#include <boost/optional.hpp>
namespace armnn
{
bool IsActivationSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
const ActivationDescriptor& descriptor,
std::string* reasonIfUnsupported = nullptr);
@@ -21,6 +24,11 @@ bool IsAdditionSupportedRef(const TensorInfo& input0,
std::string* reasonIfUnsupported = nullptr);
bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
+ const TensorInfo& mean,
+ const TensorInfo& var,
+ const TensorInfo& beta,
+ const TensorInfo& gamma,
const BatchNormalizationDescriptor& descriptor,
std::string* reasonIfUnsupported = nullptr);
@@ -35,11 +43,16 @@ bool IsConvolution2dSupportedRef(const TensorInfo& input,
std::string* reasonIfUnsupported = nullptr);
bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
const DepthwiseConvolution2dDescriptor& descriptor,
const TensorInfo& weights,
+ const TensorInfo& biases,
std::string* reasonIfUnsupported = nullptr);
bool IsFullyConnectedSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
+ const TensorInfo& weights,
+ const TensorInfo& biases,
const FullyConnectedDescriptor& descriptor,
std::string* reasonIfUnsupported = nullptr);
@@ -47,14 +60,30 @@ bool IsInputSupportedRef(const TensorInfo& input,
std::string* reasonIfUnsupported = nullptr);
bool IsL2NormalizationSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
std::string* reasonIfUnsupported = nullptr);
+bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
+ const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
+ const TensorInfo& output, const LstmDescriptor& descriptor,
+ const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
+ const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
+ const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
+ const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
+ const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
+ const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
+ const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
+ const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
+ const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported = nullptr);
+
bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
const OriginsDescriptor& descriptor,
std::string* reasonIfUnsupported = nullptr);
bool IsMultiplicationSupportedRef(const TensorInfo& input0,
const TensorInfo& input1,
+ const TensorInfo& output,
std::string* reasonIfUnsupported = nullptr);
bool IsNormalizationSupportedRef(const TensorInfo& input,
@@ -79,6 +108,7 @@ bool IsResizeBilinearSupportedRef(const TensorInfo& input,
std::string* reasonIfUnsupported = nullptr);
bool IsSoftmaxSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
const SoftmaxDescriptor& descriptor,
std::string* reasonIfUnsupported = nullptr);
@@ -97,4 +127,12 @@ bool IsFloorSupportedRef(const TensorInfo& input,
const TensorInfo& output,
std::string* reasonIfUnsupported = nullptr);
+bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
+ std::string* reasonIfUnsupported = nullptr);
+
+bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
+ std::string* reasonIfUnsupported = nullptr);
+
}