From 3695f9af9db2c14acee9af2fd68c44c737faa6ce Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Mon, 23 Apr 2018 17:41:22 +0100 Subject: COMPMID-1048 Add NHWC data format support to Winograd output transform 4x4_3x3 https://confluence.arm.com/display/MLENG/Winograd+Output+Transform%3A+NCHW+vs+NHWC+on+OpenCL Change-Id: I6995f5cef759ba70ebd96d545b952041b6f1f36e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128729 Reviewed-by: Gian Marco Iodice Tested-by: Jenkins --- tests/validation/reference/Winograd.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests/validation/reference') diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp index 194a78e95f..197d218129 100644 --- a/tests/validation/reference/Winograd.cpp +++ b/tests/validation/reference/Winograd.cpp @@ -333,8 +333,6 @@ SimpleTensor winograd_filter_transform(const SimpleTensor &in, const Tenso template SimpleTensor winograd_output_transform(const SimpleTensor &in, const SimpleTensor &b, const TensorShape &output_shape, const WinogradInfo &winograd_info) { - ARM_COMPUTE_ERROR_ON_MSG(winograd_info.output_data_layout != DataLayout::NCHW, "Only supported NCHW data format"); - const PadStrideInfo conv_info = winograd_info.convolution_info; const Size2D input_dimensions = winograd_info.input_dimensions; const Size2D output_tile_size = winograd_info.output_tile_size; @@ -350,7 +348,7 @@ SimpleTensor winograd_output_transform(const SimpleTensor &in, const Simpl const unsigned int out_tile_h = output_tile_size.height; ARM_COMPUTE_ERROR_ON(in.shape()[2] != (in_tile_w * in_tile_h)); - ARM_COMPUTE_ERROR_ON(in.shape()[0] != out.shape()[2]); + ARM_COMPUTE_ERROR_ON(in.shape()[0] != out.shape()[get_data_layout_dimension_index(winograd_info.output_data_layout, DataLayoutDimension::CHANNEL)]); // Compute tile dimensions // Input tile dimensions -- cgit v1.2.1