diff options
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 23 |
1 files changed, 15 insertions, 8 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 25c2bafe2f..45f108c2f8 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -34,6 +34,7 @@ bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported, &FalseFunc<Params...>, floatFuncPtr, uint8FuncPtr, + &FalseFunc<Params...>, std::forward<Params>(params)...); } @@ -105,10 +106,12 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, bool RefLayerSupport::IsConstantSupported(const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - return IsSupportedForDataTypeRef(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + output.GetDataType(), + &FalseFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>); } bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, @@ -119,12 +122,14 @@ bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, input.GetDataType(), &TrueFunc<>, &FalseInputFuncF32<>, - &FalseFuncU8<>) && + &FalseFuncU8<>, + &FalseFuncI32<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &FalseOutputFuncF16<>, &TrueFunc<>, - &FalseFuncU8<>)); + &FalseFuncU8<>, + &FalseFuncI32<>)); } bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, @@ -135,12 +140,14 @@ bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, input.GetDataType(), &FalseInputFuncF16<>, &TrueFunc<>, - &FalseFuncU8<>) && + &FalseFuncU8<>, + &FalseFuncI32<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &FalseOutputFuncF32<>, - &FalseFuncU8<>)); + &FalseFuncU8<>, + &FalseFuncI32<>)); } bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, |