aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/RefLayerSupport.cpp')
-rw-r--r--src/armnn/backends/RefLayerSupport.cpp99
1 files changed, 98 insertions, 1 deletions
diff --git a/src/armnn/backends/RefLayerSupport.cpp b/src/armnn/backends/RefLayerSupport.cpp
index 0b94656ded..ca4fca6f31 100644
--- a/src/armnn/backends/RefLayerSupport.cpp
+++ b/src/armnn/backends/RefLayerSupport.cpp
@@ -10,7 +10,6 @@
#include <armnn/Tensor.hpp>
#include <boost/core/ignore_unused.hpp>
-
#include "InternalTypes.hpp"
using namespace boost;
@@ -27,15 +26,18 @@ bool IsSupportedForDataTypeRef(std::string* reasonIfUnsupported,
{
return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
dataType,
+ &FalseFunc<Params...>,
floatFuncPtr,
uint8FuncPtr,
std::forward<Params>(params)...);
}
bool IsActivationSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
const ActivationDescriptor& descriptor,
std::string* reasonIfUnsupported)
{
+ ignore_unused(output);
ignore_unused(descriptor);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
@@ -57,6 +59,11 @@ bool IsAdditionSupportedRef(const TensorInfo& input0,
}
bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
+ const TensorInfo& mean,
+ const TensorInfo& var,
+ const TensorInfo& beta,
+ const TensorInfo& gamma,
const BatchNormalizationDescriptor& descriptor,
std::string* reasonIfUnsupported)
{
@@ -94,12 +101,16 @@ bool IsConvolution2dSupportedRef(const TensorInfo& input,
}
bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
const DepthwiseConvolution2dDescriptor& descriptor,
const TensorInfo& weights,
+ const TensorInfo& biases,
std::string* reasonIfUnsupported)
{
+ ignore_unused(output);
ignore_unused(descriptor);
ignore_unused(weights);
+ ignore_unused(biases);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
&TrueFunc<>,
@@ -107,10 +118,16 @@ bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
}
bool IsFullyConnectedSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
+ const TensorInfo& weights,
+ const TensorInfo& biases,
const FullyConnectedDescriptor& descriptor,
std::string* reasonIfUnsupported)
{
+ ignore_unused(output);
ignore_unused(descriptor);
+ ignore_unused(weights);
+ ignore_unused(biases);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
&TrueFunc<>,
@@ -127,8 +144,10 @@ bool IsInputSupportedRef(const TensorInfo& input,
}
bool IsL2NormalizationSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
std::string* reasonIfUnsupported)
{
+ ignore_unused(output);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
&TrueFunc<>,
@@ -148,9 +167,11 @@ bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
bool IsMultiplicationSupportedRef(const TensorInfo& input0,
const TensorInfo& input1,
+ const TensorInfo& output,
std::string* reasonIfUnsupported)
{
ignore_unused(input1);
+ ignore_unused(output);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input0.GetDataType(),
&TrueFunc<>,
@@ -212,9 +233,11 @@ bool IsResizeBilinearSupportedRef(const TensorInfo& input,
}
bool IsSoftmaxSupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
const SoftmaxDescriptor& descriptor,
std::string* reasonIfUnsupported)
{
+ ignore_unused(output);
ignore_unused(descriptor);
return IsSupportedForDataTypeRef(reasonIfUnsupported,
input.GetDataType(),
@@ -264,4 +287,78 @@ bool IsFloorSupportedRef(const TensorInfo& input,
&FalseFuncU8<>);
}
+bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
+ const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
+ const TensorInfo& output, const LstmDescriptor& descriptor,
+ const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
+ const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
+ const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
+ const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
+ const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
+ const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
+ const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
+ const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
+ const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
+{
+ ignore_unused(input);
+ ignore_unused(outputStateIn);
+ ignore_unused(cellStateIn);
+ ignore_unused(scratchBuffer);
+ ignore_unused(outputStateOut);
+ ignore_unused(cellStateOut);
+ ignore_unused(output);
+ ignore_unused(descriptor);
+ ignore_unused(inputToForgetWeights);
+ ignore_unused(inputToCellWeights);
+ ignore_unused(inputToOutputWeights);
+ ignore_unused(recurrentToForgetWeights);
+ ignore_unused(recurrentToCellWeights);
+ ignore_unused(recurrentToOutputWeights);
+ ignore_unused(forgetGateBias);
+ ignore_unused(cellBias);
+ ignore_unused(outputGateBias);
+ ignore_unused(inputToInputWeights);
+ ignore_unused(recurrentToInputWeights);
+ ignore_unused(cellToInputWeights);
+ ignore_unused(inputGateBias);
+ ignore_unused(projectionWeights);
+ ignore_unused(projectionBias);
+ ignore_unused(cellToForgetWeights);
+ ignore_unused(cellToOutputWeights);
+ return false;
+}
+
+bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
+ std::string* reasonIfUnsupported)
+{
+ return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &FalseInputFuncF32<>,
+ &FalseFuncU8<>) &&
+ IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ output.GetDataType(),
+ &FalseOutputFuncF16<>,
+ &TrueFunc<>,
+ &FalseFuncU8<>));
+}
+
+bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
+ const TensorInfo& output,
+ std::string* reasonIfUnsupported)
+{
+ return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ input.GetDataType(),
+ &FalseInputFuncF16<>,
+ &TrueFunc<>,
+ &FalseFuncU8<>) &&
+ IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ output.GetDataType(),
+ &TrueFunc<>,
+ &FalseOutputFuncF32<>,
+ &FalseFuncU8<>));
+}
+
}