aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuWinogradConv2d.cpp
diff options
context:
space:
mode:
authorRamy Elgammal <ramy.elgammal@arm.com>2022-09-07 12:38:46 +0100
committerRamy Elgammal <ramy.elgammal@arm.com>2022-09-08 11:19:59 +0000
commite4e3b2ead5b6720af8039f3c9ac15ea6b51b915f (patch)
tree954353cf90c8c52d51d02f6bd971f4ea84a7f646 /src/cpu/operators/CpuWinogradConv2d.cpp
parent211a55d8218764c0a20d69d4cbdaea1906291c6b (diff)
downloadComputeLibrary-e4e3b2ead5b6720af8039f3c9ac15ea6b51b915f.tar.gz
Disable Winograd on fp16 if fast-math = false
- That would force CpuConv2d::get_convolution_method() choose GEMM_CONV2D or GEMM methods instead. Resolves: COMPMID-5531 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Change-Id: I4dec8772e8c150da003d9a89c1d036057c4d28b0 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8233 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuWinogradConv2d.cpp')
-rw-r--r--src/cpu/operators/CpuWinogradConv2d.cpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/src/cpu/operators/CpuWinogradConv2d.cpp b/src/cpu/operators/CpuWinogradConv2d.cpp
index 81cf651b76..c4edd89964 100644
--- a/src/cpu/operators/CpuWinogradConv2d.cpp
+++ b/src/cpu/operators/CpuWinogradConv2d.cpp
@@ -163,7 +163,7 @@ void CpuWinogradConv2d::configure(const ITensorInfo *src, const ITensorInfo *wei
const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, bool enable_fast_math)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info));
+ ARM_COMPUTE_ERROR_THROW_ON(validate(src, weights, biases, dst, conv_info, act_info, enable_fast_math));
ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info, act_info, enable_fast_math);
ARM_COMPUTE_UNUSED(biases);
const DataType data_type = src->data_type();
@@ -294,6 +294,12 @@ Status CpuWinogradConv2d::validate(const ITensorInfo *src, const ITensorInfo *we
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info));
+ // Disable winograd for fp16 if fast math is false.
+ if(!enable_fast_math)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32);
+ }
+
const Tensor4DShape kernel_shape{ internal_get_shape(weights) };
arm_conv::winograd::WinogradImpl winograd_impl{};