From e4e3b2ead5b6720af8039f3c9ac15ea6b51b915f Mon Sep 17 00:00:00 2001 From: Ramy Elgammal Date: Wed, 7 Sep 2022 12:38:46 +0100 Subject: Disable Winograd on fp16 if fast-math = false - That would force CpuConv2d::get_convolution_method() choose GEMM_CONV2D or GEMM methods instead. Resolves: COMPMID-5531 Signed-off-by: Ramy Elgammal Change-Id: I4dec8772e8c150da003d9a89c1d036057c4d28b0 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8233 Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice --- src/cpu/operators/CpuWinogradConv2d.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'src/cpu') diff --git a/src/cpu/operators/CpuWinogradConv2d.cpp b/src/cpu/operators/CpuWinogradConv2d.cpp index 81cf651b76..c4edd89964 100644 --- a/src/cpu/operators/CpuWinogradConv2d.cpp +++ b/src/cpu/operators/CpuWinogradConv2d.cpp @@ -163,7 +163,7 @@ void CpuWinogradConv2d::configure(const ITensorInfo *src, const ITensorInfo *wei const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, bool enable_fast_math) { ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info)); + ARM_COMPUTE_ERROR_THROW_ON(validate(src, weights, biases, dst, conv_info, act_info, enable_fast_math)); ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info, act_info, enable_fast_math); ARM_COMPUTE_UNUSED(biases); const DataType data_type = src->data_type(); @@ -294,6 +294,12 @@ Status CpuWinogradConv2d::validate(const ITensorInfo *src, const ITensorInfo *we ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst); ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info)); + // Disable winograd for fp16 if fast math is false. + if(!enable_fast_math) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32); + } + const Tensor4DShape kernel_shape{ internal_get_shape(weights) }; arm_conv::winograd::WinogradImpl winograd_impl{}; -- cgit v1.2.1