diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/Im2Col.cpp | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp index 825f0a6ee1..d309b7d5e6 100644 --- a/tests/validation/reference/Im2Col.cpp +++ b/tests/validation/reference/Im2Col.cpp @@ -88,14 +88,12 @@ SimpleTensor<T> im2col(const SimpleTensor<T> &src, const TensorShape &dst_shape, if(src.data_layout() == DataLayout::NHWC) { SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); - SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U)); - - im2col_nchw(src_nchw, dst_nchw, kernel_dims, conv_info, has_bias); - - return reference::permute<T>(dst_nchw, PermutationVector(2U, 0U, 1U)); + im2col_nchw(src_nchw, dst, kernel_dims, conv_info, has_bias); + } + else + { + im2col_nchw(src, dst, kernel_dims, conv_info, has_bias); } - - im2col_nchw(src, dst, kernel_dims, conv_info, has_bias); return dst; } |