diff options
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 38 |
1 files changed, 25 insertions, 13 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 45f108c2f8..78e44bd6a3 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -35,6 +35,7 @@ bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported, floatFuncPtr, uint8FuncPtr, &FalseFunc<Params...>, + &FalseFunc<Params...>, std::forward<Params>(params)...); } @@ -111,7 +112,8 @@ bool RefLayerSupport::IsConstantSupported(const TensorInfo& output, &FalseFunc<>, &TrueFunc<>, &TrueFunc<>, - &TrueFunc<>); + &TrueFunc<>, + &FalseFunc<>); } bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, @@ -123,13 +125,15 @@ bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, &TrueFunc<>, &FalseInputFuncF32<>, &FalseFuncU8<>, - &FalseFuncI32<>) && + &FalseFuncI32<>, + &FalseFuncU8<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &FalseOutputFuncF16<>, &TrueFunc<>, &FalseFuncU8<>, - &FalseFuncI32<>)); + &FalseFuncI32<>, + &FalseFuncU8<>)); } bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, @@ -141,13 +145,15 @@ bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, &FalseInputFuncF16<>, &TrueFunc<>, &FalseFuncU8<>, - &FalseFuncI32<>) && + &FalseFuncI32<>, + &FalseFuncU8<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &FalseOutputFuncF32<>, &FalseFuncU8<>, - &FalseFuncI32<>)); + &FalseFuncI32<>, + &FalseFuncU8<>)); } bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, @@ -415,10 +421,13 @@ bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input, Optional<std::string &> reasonIfUnsupported) const { ignore_unused(output); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &FalseFuncI32<>, + &TrueFunc<>); } bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0, @@ -463,10 +472,13 @@ bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input, bool RefLayerSupport::IsOutputSupported(const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - return IsSupportedForDataTypeRef(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + output.GetDataType(), + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &FalseFuncI32<>, + &TrueFunc<>); } bool RefLayerSupport::IsPadSupported(const TensorInfo& input, |