author     Giorgio Arena <giorgio.arena@arm.com>      2018-04-23 17:41:22 +0100
committer  Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:52:54 +0000
commit     3695f9af9db2c14acee9af2fd68c44c737faa6ce (patch)
tree       87aa336d6263cb00b01d5277b19178e80782f57a
parent     b62280aca3148dd6762e57e5af3da0cb0a9e2db5 (diff)
COMPMID-1048 Add NHWC data format support to Winograd output transform 4x4_3x3
https://confluence.arm.com/display/MLENG/Winograd+Output+Transform%3A+NCHW+vs+NHWC+on+OpenCL

Change-Id: I6995f5cef759ba70ebd96d545b952041b6f1f36e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128729
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
 arm_compute/core/CL/kernels/CLWinogradOutputTransformKernel.h |   2
 src/core/CL/CLKernelLibrary.cpp                               |   1
 src/core/CL/cl_kernels/winograd.cl                            | 189
 src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp       |   3
 tests/datasets/WinogradOutputTransformDataset.h               |  13
 tests/validation/fixtures/WinogradConvolutionLayerFixture.h   |  16
 tests/validation/reference/Winograd.cpp                       |   4
 7 files changed, 217 insertions(+), 11 deletions(-)
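
[Editor's note] The heart of the patch is the new NHWC kernel in winograd.cl below. A minimal C++ sketch (illustrative names, mirroring the index arithmetic near the end of that kernel) of how each work-item's y coordinate is decomposed row-major into a 4x4 tile position; in NHWC the y stride walks the output width and the z stride the output height:

#include <cstdio>

int main()
{
    const int NUM_TILES_X = 2; // e.g. built with -DNUM_TILES_X=2
    for(int y_in = 0; y_in < 4; ++y_in)
    {
        const int y_out = (y_in % NUM_TILES_X) * 4; // tile column -> output x (dst_stride_y)
        const int z_out = (y_in / NUM_TILES_X) * 4; // tile row    -> output y (dst_stride_z)
        std::printf("tile %d -> (y_out=%d, z_out=%d)\n", y_in, y_out, z_out);
    }
    return 0;
}

Because channels are innermost in NHWC, x_out indexes the output feature map directly and the 36 Winograd channels of one tile are gathered along the source z stride, as the loads at the top of the kernel show.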
diff --git a/arm_compute/core/CL/kernels/CLWinogradOutputTransformKernel.h b/arm_compute/core/CL/kernels/CLWinogradOutputTransformKernel.h
index 5e64a82e48..03e3bf5740 100644
--- a/arm_compute/core/CL/kernels/CLWinogradOutputTransformKernel.h
+++ b/arm_compute/core/CL/kernels/CLWinogradOutputTransformKernel.h
@@ -51,6 +51,7 @@ public:
* @note Winograd output transform supports the following configurations:
* F(output tile, kernel size):F(2x2, 3x3), F(4x4, 3x3), F(4x4, 5x5)
* Strides: only unit strides
+ * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3)
*
* @param[in] input Source tensor with shape [C, N, K, batches]. Data types supported: F32.
* @param[in] bias Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. It can be a nullptr. Data type supported: as @p input
@@ -63,6 +64,7 @@ public:
* @note Winograd output transform supports the following configurations:
* F(output tile, kernel size):F(2x2, 3x3), F(4x4, 3x3), F(4x4, 5x5)
* Strides: only unit strides
+ * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3)
*
* @param[in] input Source tensor with shape [C, N, K, batches]. Data types supported: F32.
* @param[in] bias Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. It can be a nullptr. Data type supported: as @p input
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index b4531b841b..f74c6c8a4a 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -376,6 +376,7 @@ const std::map<std::string, std::string> CLKernelLibrary::_kernel_program_map =
{ "winograd_output_transform_2x2_3x3_nchw", "winograd.cl" },
{ "winograd_output_transform_4x4_3x3_nchw", "winograd.cl" },
{ "winograd_output_transform_4x4_5x5_nchw", "winograd.cl" },
+ { "winograd_output_transform_4x4_3x3_nhwc", "winograd.cl" },
{ "YUYV422_to_IYUV_bt709", "color_convert.cl" },
{ "YUYV422_to_NV12_bt709", "color_convert.cl" },
{ "YUYV422_to_RGB888_bt709", "color_convert.cl" },
diff --git a/src/core/CL/cl_kernels/winograd.cl b/src/core/CL/cl_kernels/winograd.cl
index 0458e53734..14bebb4b0b 100644
--- a/src/core/CL/cl_kernels/winograd.cl
+++ b/src/core/CL/cl_kernels/winograd.cl
@@ -1417,6 +1417,195 @@ __kernel void winograd_output_transform_4x4_3x3_nchw(
vstore4((float4)(out30, out31, out32, out33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
}
+/** This OpenCL kernel performs the Winograd output transform when the output tile is 4x4, the filter size is 3x3 and the data layout is NHWC
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x4_3x3_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ // Each thread stores a 4x4 tile
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+ const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
+
+ // Load the values across the 36 channels to compose the 6x6 tile
+ float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
+ float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
+ float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
+ float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
+
+ float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
+ float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
+ float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
+ float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
+ float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
+ float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
+
+ float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
+ float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
+ float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
+ float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
+ float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
+ float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
+
+ float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
+ float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
+ float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
+ float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
+ float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
+ float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
+
+ float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
+ float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
+ float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
+ float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
+ float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
+ float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
+
+ float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
+ float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
+ float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
+ float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
+ float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
+ float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
+
+ // Compute out00, out01, out02 and out03
+ float out00 = d01 + d21 + d41 + d11 + d31;
+ float out01 = d01 + d21 + d41 + d11 + d31;
+ float out02 = d01 + d21 + d41 + d11 + d31;
+ float out03 = d01 + d21 + d41 + d11 + d31;
+
+ float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
+ float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
+
+ out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
+ out01 += k1 - d02 - d12 - d22 - d32 - d42;
+ out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
+ out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
+
+ // Compute out10, out11, out12 and out13
+ float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+
+ k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
+ k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
+
+ out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
+ out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
+ out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
+ out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
+
+ // Compute out20, out21, out22 and out23
+ float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+
+ k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
+ k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
+
+ out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
+ out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
+ out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
+ out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
+
+ // Compute out30, out31, out32 and out33
+ float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+
+ k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
+ k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
+
+ out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
+ out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
+ out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
+ out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
+
+ int y_in = get_global_id(1);
+ int x_out = get_global_id(0);
+ int y_out = (y_in % NUM_TILES_X) * 4;
+ int z_out = (y_in / NUM_TILES_X) * 4;
+
+#if defined(HAS_BIAS)
+ // Add bias
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
+
+ out00 += (float)b;
+ out01 += (float)b;
+ out02 += (float)b;
+ out03 += (float)b;
+
+ out10 += (float)b;
+ out11 += (float)b;
+ out12 += (float)b;
+ out13 += (float)b;
+
+ out20 += (float)b;
+ out21 += (float)b;
+ out22 += (float)b;
+ out23 += (float)b;
+
+ out30 += (float)b;
+ out31 += (float)b;
+ out32 += (float)b;
+ out33 += (float)b;
+
+#endif // defined(HAS_BIAS)
+
+ // Get output address
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
+
+ // Store the 4x4 output tile
+ *((__global float *)(dst_addr + 0 * dst_stride_y + 0 * dst_stride_z)) = out00;
+ *((__global float *)(dst_addr + 1 * dst_stride_y + 0 * dst_stride_z)) = out01;
+ *((__global float *)(dst_addr + 2 * dst_stride_y + 0 * dst_stride_z)) = out02;
+ *((__global float *)(dst_addr + 3 * dst_stride_y + 0 * dst_stride_z)) = out03;
+ *((__global float *)(dst_addr + 0 * dst_stride_y + 1 * dst_stride_z)) = out10;
+ *((__global float *)(dst_addr + 1 * dst_stride_y + 1 * dst_stride_z)) = out11;
+ *((__global float *)(dst_addr + 2 * dst_stride_y + 1 * dst_stride_z)) = out12;
+ *((__global float *)(dst_addr + 3 * dst_stride_y + 1 * dst_stride_z)) = out13;
+ *((__global float *)(dst_addr + 0 * dst_stride_y + 2 * dst_stride_z)) = out20;
+ *((__global float *)(dst_addr + 1 * dst_stride_y + 2 * dst_stride_z)) = out21;
+ *((__global float *)(dst_addr + 2 * dst_stride_y + 2 * dst_stride_z)) = out22;
+ *((__global float *)(dst_addr + 3 * dst_stride_y + 2 * dst_stride_z)) = out23;
+ *((__global float *)(dst_addr + 0 * dst_stride_y + 3 * dst_stride_z)) = out30;
+ *((__global float *)(dst_addr + 1 * dst_stride_y + 3 * dst_stride_z)) = out31;
+ *((__global float *)(dst_addr + 2 * dst_stride_y + 3 * dst_stride_z)) = out32;
+ *((__global float *)(dst_addr + 3 * dst_stride_y + 3 * dst_stride_z)) = out33;
+}
+
#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact) \
({ \
comm_fact.s0 = d1 + d2; \
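
[Editor's note] The unrolled expressions in the new kernel are the expanded form of out = A^T * d * A with the standard F(4x4, 3x3) transform matrix. A self-contained check for the out00 term (assumption: the conventional A^T, whose coefficients match the kernel's):

#include <cstdio>

int main()
{
    // A^T for F(4x4, 3x3); each output row mixes the six Winograd channels per axis.
    const float At[4][6] = {
        { 1, 1, 1, 1, 1, 0 },
        { 0, 1, -1, 2, -2, 0 },
        { 0, 1, 1, 4, 4, 0 },
        { 0, 1, -1, 8, -8, 1 },
    };

    float d[6][6]; // the 6x6 input tile (d00..d55 in the kernel)
    for(int i = 0; i < 6; ++i)
        for(int j = 0; j < 6; ++j)
            d[i][j] = static_cast<float>(i * 6 + j + 1);

    // Matrix form: out = A^T * d * A
    float out[4][4] = {};
    for(int i = 0; i < 4; ++i)
        for(int j = 0; j < 4; ++j)
            for(int k = 0; k < 6; ++k)
                for(int l = 0; l < 6; ++l)
                    out[i][j] += At[i][k] * d[k][l] * At[j][l];

    // Kernel form for out00, grouped exactly as in the unrolled code above.
    float out00 = d[0][1] + d[2][1] + d[4][1] + d[1][1] + d[3][1];
    float k0 = 0.0f;
    for(int k = 0; k < 5; ++k)
        k0 += d[k][3] + d[k][4];
    out00 += k0;
    for(int k = 0; k < 5; ++k)
        out00 += d[k][0] + d[k][2];

    std::printf("matrix %.1f vs unrolled %.1f\n", out[0][0], out00); // both print the same value
    return 0;
}

The identical initializations of out00..out03 (and of each later quadruple) are intentional: the shared row sum is computed once and the per-column differences are added afterwards via k0 and k1.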
diff --git a/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
index 5c0a7351eb..416d8e8d5f 100644
--- a/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
@@ -48,7 +48,6 @@ namespace
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const WinogradInfo &winograd_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(winograd_info.output_data_layout != DataLayout::NCHW);
const PadStrideInfo conv_info = winograd_info.convolution_info;
const Size2D output_tile_size = winograd_info.output_tile_size;
@@ -146,7 +145,7 @@ void CLWinogradOutputTransformKernel::configure(const ICLTensor *input, const IC
build_opts.add_option("-DNUM_TILES_X=" + support::cpp11::to_string(num_tiles_x));
// Create kernel
- std::string kernel_name = "winograd_output_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string() + "_nchw";
+ std::string kernel_name = "winograd_output_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string() + "_" + lower_string(string_from_data_layout(winograd_info.output_data_layout));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Configure kernel window
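
[Editor's note] With the hard-coded "_nchw" suffix gone, configure() derives the kernel name from the requested layout. A standalone sketch of that construction; the two helpers are stand-ins for the ACL utilities of the same name, and the tile/kernel strings assume Size2D::to_string() yields e.g. "4x4":

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>

enum class DataLayout { NCHW, NHWC };

static std::string string_from_data_layout(DataLayout dl) // stand-in for the ACL utility
{
    return dl == DataLayout::NCHW ? "NCHW" : "NHWC";
}

static std::string lower_string(std::string s) // stand-in for the ACL utility
{
    std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
    return s;
}

int main()
{
    const std::string output_tile = "4x4", kernel = "3x3";
    const std::string name = "winograd_output_transform_" + output_tile + "_" + kernel + "_"
                             + lower_string(string_from_data_layout(DataLayout::NHWC));
    std::cout << name << '\n'; // winograd_output_transform_4x4_3x3_nhwc
    return 0;
}

This is also why the NCHW-only check is dropped from validate_arguments(): the layout now simply selects which kernel gets built, and the new name resolves via the CLKernelLibrary entry added above.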
diff --git a/tests/datasets/WinogradOutputTransformDataset.h b/tests/datasets/WinogradOutputTransformDataset.h
index b28df390d1..bf6f5cd2a9 100644
--- a/tests/datasets/WinogradOutputTransformDataset.h
+++ b/tests/datasets/WinogradOutputTransformDataset.h
@@ -118,6 +118,12 @@ public:
add_config(TensorShape(7U, 4U, 36U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
add_config(TensorShape(24U, 16U, 36U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW));
add_config(TensorShape(7U, 12U, 16U, 5U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
+ // NHWC
+ add_config(TensorShape(13U, 4U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(10U, 9U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));
+ add_config(TensorShape(13U, 6U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));
+ add_config(TensorShape(7U, 117U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC));
+ add_config(TensorShape(7U, 4U, 36U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));
+ add_config(TensorShape(24U, 16U, 36U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC));
// (4x4, 5x5)
add_config(TensorShape(13U, 1U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
@@ -149,6 +155,13 @@ public:
add_config(TensorShape(64U, 3136U, 36U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW));
add_config(TensorShape(32U, 784U, 36U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
add_config(TensorShape(13U, 196U, 36U, 5U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
+ // NHWC
+ add_config(TensorShape(64U, 3136U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC));
+ add_config(TensorShape(32U, 784U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC));
+ add_config(TensorShape(13U, 196U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC));
+ add_config(TensorShape(64U, 3136U, 36U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC));
+ add_config(TensorShape(32U, 784U, 36U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC));
+ add_config(TensorShape(13U, 196U, 36U, 5U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC));
}
};
} // namespace datasets
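
[Editor's note] The new NHWC entries follow the same shape convention as the NCHW ones: dimension 1 of the transformed tensor is the number of output tiles. A sketch (assuming ceil-division tiling at unit stride) reproducing the 117 tiles of the TensorShape(7U, 117U, 36U) entry:

#include <cstdio>

int main()
{
    const int in_w = 53, in_h = 33;  // Size2D(53U, 33U) input dimensions
    const int pad_x = 0, pad_y = 1;  // PadStrideInfo(1, 1, 0, 1)
    const int k = 3, tile = 4;       // 3x3 kernel, 4x4 output tile

    const int out_w = in_w + 2 * pad_x - k + 1; // convolution output at unit stride
    const int out_h = in_h + 2 * pad_y - k + 1;

    const int tiles_x = (out_w + tile - 1) / tile; // ceil division
    const int tiles_y = (out_h + tile - 1) / tile;

    std::printf("%d x %d = %d tiles\n", tiles_x, tiles_y, tiles_x * tiles_y); // 13 x 9 = 117
    return 0;
}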
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
index ef596e0bae..e23368add6 100644
--- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
@@ -36,6 +36,7 @@
#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ConvolutionLayer.h"
#include "tests/validation/reference/GEMM.h"
+#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/Utils.h"
#include "tests/validation/reference/Winograd.h"
@@ -440,10 +441,8 @@ public:
template <typename...>
void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type)
{
- TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
-
- _target = compute_target(input_shape, output_shape, winograd_info, data_type);
- _reference = compute_reference(input_shape, output_shape, winograd_info, data_type);
+ _target = compute_target(input_shape, winograd_info, data_type);
+ _reference = compute_reference(input_shape, winograd_info, data_type);
}
protected:
@@ -467,8 +466,10 @@ protected:
}
}
- TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type)
+ TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type)
{
+ TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
+
// Create tensors
TensorType src = create_tensor<TensorType>(input_shape, data_type);
TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, QuantizationInfo(), winograd_info.output_data_layout);
@@ -495,8 +496,11 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type)
{
+ winograd_info.output_data_layout = DataLayout::NCHW;
+ TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
+
// Create reference
SimpleTensor<T> src{ input_shape, data_type };
SimpleTensor<T> bias{ TensorShape(input_shape[0]), data_type };
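
[Editor's note] compute_reference() now pins the reference to NCHW regardless of the layout under test, and the Permute.h include added above suggests (an assumption; the remainder of the fixture is not shown here) that the NCHW reference output is permuted to NHWC before validation. An illustrative permute, not the ACL reference implementation:

#include <cstddef>
#include <vector>

// Relabel a dense [N][C][H][W] buffer into [N][H][W][C] order.
std::vector<float> nchw_to_nhwc(const std::vector<float> &in,
                                std::size_t n, std::size_t c, std::size_t h, std::size_t w)
{
    std::vector<float> out(in.size());
    for(std::size_t b = 0; b < n; ++b)
        for(std::size_t ch = 0; ch < c; ++ch)
            for(std::size_t y = 0; y < h; ++y)
                for(std::size_t x = 0; x < w; ++x)
                    out[((b * h + y) * w + x) * c + ch] = in[((b * c + ch) * h + y) * w + x];
    return out;
}

int main()
{
    const std::vector<float> nchw = { 0, 1, 2, 3, 10, 11, 12, 13 }; // N=1, C=2, H=2, W=2
    const std::vector<float> nhwc = nchw_to_nhwc(nchw, 1, 2, 2, 2);
    // Channels are now interleaved per pixel: 0,10, 1,11, 2,12, 3,13.
    return nhwc[1] == 10.0f ? 0 : 1;
}

Deferring the output-shape computation into compute_target()/compute_reference() lets each side derive the shape for its own layout from the same WinogradInfo.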
diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp
index 194a78e95f..197d218129 100644
--- a/tests/validation/reference/Winograd.cpp
+++ b/tests/validation/reference/Winograd.cpp
@@ -333,8 +333,6 @@ SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const Tenso
template <typename T>
SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const SimpleTensor<T> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info)
{
- ARM_COMPUTE_ERROR_ON_MSG(winograd_info.output_data_layout != DataLayout::NCHW, "Only supported NCHW data format");
-
const PadStrideInfo conv_info = winograd_info.convolution_info;
const Size2D input_dimensions = winograd_info.input_dimensions;
const Size2D output_tile_size = winograd_info.output_tile_size;
@@ -350,7 +348,7 @@ SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const Simpl
const unsigned int out_tile_h = output_tile_size.height;
ARM_COMPUTE_ERROR_ON(in.shape()[2] != (in_tile_w * in_tile_h));
- ARM_COMPUTE_ERROR_ON(in.shape()[0] != out.shape()[2]);
+ ARM_COMPUTE_ERROR_ON(in.shape()[0] != out.shape()[get_data_layout_dimension_index(winograd_info.output_data_layout, DataLayoutDimension::CHANNEL)]);
// Compute tile dimensions
// Input tile dimensions
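
[Editor's note] The reworked assert is needed because the channel dimension moves with the layout: in ACL's shape ordering it sits at index 2 for NCHW but at index 0 for NHWC, so the hard-coded out.shape()[2] only held for NCHW. A sketch, where channel_index() is a stand-in for get_data_layout_dimension_index(layout, DataLayoutDimension::CHANNEL):

#include <cassert>

enum class DataLayout { NCHW, NHWC };

int channel_index(DataLayout dl) // stand-in for the ACL utility
{
    return dl == DataLayout::NCHW ? 2 : 0;
}

int main()
{
    assert(channel_index(DataLayout::NCHW) == 2); // matches the old out.shape()[2]
    assert(channel_index(DataLayout::NHWC) == 0); // channels innermost in NHWC
    return 0;
}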