aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp4
-rw-r--r--tests/validation/CL/Winograd.cpp4
2 files changed, 6 insertions, 2 deletions
diff --git a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
index a70389ab6c..70bf3ae593 100644
--- a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
@@ -158,6 +158,10 @@ Status CLWinogradConvolutionLayer::validate(const ITensorInfo *input, const ITen
const Size2D kernel_size = Size2D(weights->tensor_shape()[idx_width], weights->tensor_shape()[idx_height]);
const Size2D output_tile = winograd_output_tile(input_dims, kernel_size, input->data_layout());
+ //FP16 implementation of winograd is slower than direct convolution.
+ //The following check needs to be removed when fp16 winograd is faster than direct convolution (COMPMID-1266)
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+
// Check if the Winograd configuration requires fast math
if(!enable_fast_math)
{
diff --git a/tests/validation/CL/Winograd.cpp b/tests/validation/CL/Winograd.cpp
index 3762e397ac..930f7aa8ce 100644
--- a/tests/validation/CL/Winograd.cpp
+++ b/tests/validation/CL/Winograd.cpp
@@ -834,7 +834,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture, fram
TEST_SUITE_END() // Conv1x5
TEST_SUITE_END() // FP32
-
+#ifdef WINOGRAD_F16_SUPPORT //to be reintroduced after COMPMID-1266 is resolved
TEST_SUITE(FP16)
using CLWinogradConvolutionLayerFastMathFixture16 = WinogradConvolutionLayerFastMathValidationFixture<CLTensor, CLAccessor, CLWinogradConvolutionLayer, half>;
@@ -977,7 +977,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture16, fr
TEST_SUITE_END() // Conv1x5
TEST_SUITE_END() // FP16
-
+#endif /*#ifdef WINOGRAD_F16_SUPPORT*/
TEST_SUITE_END() // ConvolutionLayer
TEST_SUITE_END() // Winograd
TEST_SUITE_END() // CL