From 2b4d88e34ac1f965417fd236fd4786f26bae2042 Mon Sep 17 00:00:00 2001 From: kevmay01 Date: Thu, 24 Jan 2019 14:05:09 +0000 Subject: IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater * Remove Equal and Greater from RefElementwiseWorkload * Create RefComparisonWorkload and add Equal and Greater * Update ElementwiseFunction for different input/output types * Update TfParser to create Equal/Greater with Boolean output * Update relevant tests to check for Boolean comparison Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32 --- src/backends/reference/RefLayerSupport.cpp | 38 ++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 13 deletions(-) (limited to 'src/backends/reference/RefLayerSupport.cpp') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 45f108c2f8..78e44bd6a3 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -35,6 +35,7 @@ bool IsSupportedForDataTypeRef(Optional reasonIfUnsupported, floatFuncPtr, uint8FuncPtr, &FalseFunc, + &FalseFunc, std::forward(params)...); } @@ -111,7 +112,8 @@ bool RefLayerSupport::IsConstantSupported(const TensorInfo& output, &FalseFunc<>, &TrueFunc<>, &TrueFunc<>, - &TrueFunc<>); + &TrueFunc<>, + &FalseFunc<>); } bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, @@ -123,13 +125,15 @@ bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, &TrueFunc<>, &FalseInputFuncF32<>, &FalseFuncU8<>, - &FalseFuncI32<>) && + &FalseFuncI32<>, + &FalseFuncU8<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &FalseOutputFuncF16<>, &TrueFunc<>, &FalseFuncU8<>, - &FalseFuncI32<>)); + &FalseFuncI32<>, + &FalseFuncU8<>)); } bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, @@ -141,13 +145,15 @@ bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, &FalseInputFuncF16<>, &TrueFunc<>, &FalseFuncU8<>, - &FalseFuncI32<>) && + &FalseFuncI32<>, + &FalseFuncU8<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &FalseOutputFuncF32<>, &FalseFuncU8<>, - &FalseFuncI32<>)); + &FalseFuncI32<>, + &FalseFuncU8<>)); } bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, @@ -415,10 +421,13 @@ bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input, Optional reasonIfUnsupported) const { ignore_unused(output); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &FalseFuncI32<>, + &TrueFunc<>); } bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0, @@ -463,10 +472,13 @@ bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input, bool RefLayerSupport::IsOutputSupported(const TensorInfo& output, Optional reasonIfUnsupported) const { - return IsSupportedForDataTypeRef(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + output.GetDataType(), + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &FalseFuncI32<>, + &TrueFunc<>); } bool RefLayerSupport::IsPadSupported(const TensorInfo& input, -- cgit v1.2.1