diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/fixtures/WinogradConvolutionLayerFixture.h | 16 | ||||
-rw-r--r-- | tests/validation/reference/Winograd.cpp | 4 |
2 files changed, 11 insertions, 9 deletions
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h index ef596e0bae..e23368add6 100644 --- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h +++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h @@ -36,6 +36,7 @@ #include "tests/validation/reference/ActivationLayer.h" #include "tests/validation/reference/ConvolutionLayer.h" #include "tests/validation/reference/GEMM.h" +#include "tests/validation/reference/Permute.h" #include "tests/validation/reference/Utils.h" #include "tests/validation/reference/Winograd.h" @@ -440,10 +441,8 @@ public: template <typename...> void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type) { - TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); - - _target = compute_target(input_shape, output_shape, winograd_info, data_type); - _reference = compute_reference(input_shape, output_shape, winograd_info, data_type); + _target = compute_target(input_shape, winograd_info, data_type); + _reference = compute_reference(input_shape, winograd_info, data_type); } protected: @@ -467,8 +466,10 @@ protected: } } - TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type) + TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type) { + TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); + // Create tensors TensorType src = create_tensor<TensorType>(input_shape, data_type); TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, QuantizationInfo(), winograd_info.output_data_layout); @@ -495,8 +496,11 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type) + SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type) { + winograd_info.output_data_layout = DataLayout::NCHW; + TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); + // Create reference SimpleTensor<T> src{ input_shape, data_type }; SimpleTensor<T> bias{ TensorShape(input_shape[0]), data_type }; diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp index 194a78e95f..197d218129 100644 --- a/tests/validation/reference/Winograd.cpp +++ b/tests/validation/reference/Winograd.cpp @@ -333,8 +333,6 @@ SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const Tenso template <typename T> SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const SimpleTensor<T> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info) { - ARM_COMPUTE_ERROR_ON_MSG(winograd_info.output_data_layout != DataLayout::NCHW, "Only supported NCHW data format"); - const PadStrideInfo conv_info = winograd_info.convolution_info; const Size2D input_dimensions = winograd_info.input_dimensions; const Size2D output_tile_size = winograd_info.output_tile_size; @@ -350,7 +348,7 @@ SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const Simpl const unsigned int out_tile_h = output_tile_size.height; ARM_COMPUTE_ERROR_ON(in.shape()[2] != (in_tile_w * in_tile_h)); - ARM_COMPUTE_ERROR_ON(in.shape()[0] != out.shape()[2]); + ARM_COMPUTE_ERROR_ON(in.shape()[0] != out.shape()[get_data_layout_dimension_index(winograd_info.output_data_layout, DataLayoutDimension::CHANNEL)]); // Compute tile dimensions // Input tile dimensions |