aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/Winograd.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/Winograd.cpp')
-rw-r--r--tests/validation/reference/Winograd.cpp4
1 files changed, 1 insertions, 3 deletions
diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp
index 194a78e95f..197d218129 100644
--- a/tests/validation/reference/Winograd.cpp
+++ b/tests/validation/reference/Winograd.cpp
@@ -333,8 +333,6 @@ SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const Tenso
template <typename T>
SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const SimpleTensor<T> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info)
{
- ARM_COMPUTE_ERROR_ON_MSG(winograd_info.output_data_layout != DataLayout::NCHW, "Only supported NCHW data format");
-
const PadStrideInfo conv_info = winograd_info.convolution_info;
const Size2D input_dimensions = winograd_info.input_dimensions;
const Size2D output_tile_size = winograd_info.output_tile_size;
@@ -350,7 +348,7 @@ SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const Simpl
const unsigned int out_tile_h = output_tile_size.height;
ARM_COMPUTE_ERROR_ON(in.shape()[2] != (in_tile_w * in_tile_h));
- ARM_COMPUTE_ERROR_ON(in.shape()[0] != out.shape()[2]);
+ ARM_COMPUTE_ERROR_ON(in.shape()[0] != out.shape()[get_data_layout_dimension_index(winograd_info.output_data_layout, DataLayoutDimension::CHANNEL)]);
// Compute tile dimensions
// Input tile dimensions