aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/Im2Col.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/Im2Col.cpp')
-rw-r--r--tests/validation/reference/Im2Col.cpp12
1 files changed, 5 insertions, 7 deletions
diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp
index 825f0a6ee1..d309b7d5e6 100644
--- a/tests/validation/reference/Im2Col.cpp
+++ b/tests/validation/reference/Im2Col.cpp
@@ -88,14 +88,12 @@ SimpleTensor<T> im2col(const SimpleTensor<T> &src, const TensorShape &dst_shape,
if(src.data_layout() == DataLayout::NHWC)
{
SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U));
- SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U));
-
- im2col_nchw(src_nchw, dst_nchw, kernel_dims, conv_info, has_bias);
-
- return reference::permute<T>(dst_nchw, PermutationVector(2U, 0U, 1U));
+ im2col_nchw(src_nchw, dst, kernel_dims, conv_info, has_bias);
+ }
+ else
+ {
+ im2col_nchw(src, dst, kernel_dims, conv_info, has_bias);
}
-
- im2col_nchw(src, dst, kernel_dims, conv_info, has_bias);
return dst;
}