From 3695f9af9db2c14acee9af2fd68c44c737faa6ce Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Mon, 23 Apr 2018 17:41:22 +0100 Subject: COMPMID-1048 Add NHWC data format support to Winograd output transform 4x4_3x3 https://confluence.arm.com/display/MLENG/Winograd+Output+Transform%3A+NCHW+vs+NHWC+on+OpenCL Change-Id: I6995f5cef759ba70ebd96d545b952041b6f1f36e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128729 Reviewed-by: Gian Marco Iodice Tested-by: Jenkins --- .../fixtures/WinogradConvolutionLayerFixture.h | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'tests/validation/fixtures/WinogradConvolutionLayerFixture.h') diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h index ef596e0bae..e23368add6 100644 --- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h +++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h @@ -36,6 +36,7 @@ #include "tests/validation/reference/ActivationLayer.h" #include "tests/validation/reference/ConvolutionLayer.h" #include "tests/validation/reference/GEMM.h" +#include "tests/validation/reference/Permute.h" #include "tests/validation/reference/Utils.h" #include "tests/validation/reference/Winograd.h" @@ -440,10 +441,8 @@ public: template void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type) { - TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); - - _target = compute_target(input_shape, output_shape, winograd_info, data_type); - _reference = compute_reference(input_shape, output_shape, winograd_info, data_type); + _target = compute_target(input_shape, winograd_info, data_type); + _reference = compute_reference(input_shape, winograd_info, data_type); } protected: @@ -467,8 +466,10 @@ protected: } } - TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type) + TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type) { + TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); + // Create tensors TensorType src = create_tensor(input_shape, data_type); TensorType dst = create_tensor(output_shape, data_type, 1, 0, QuantizationInfo(), winograd_info.output_data_layout); @@ -495,8 +496,11 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type) + SimpleTensor compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type) { + winograd_info.output_data_layout = DataLayout::NCHW; + TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); + // Create reference SimpleTensor src{ input_shape, data_type }; SimpleTensor bias{ TensorShape(input_shape[0]), data_type }; -- cgit v1.2.1