diff options
Diffstat (limited to 'tests/validation/reference/Im2Col.cpp')
-rw-r--r-- | tests/validation/reference/Im2Col.cpp | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp index 5685b60026..83ef8b40a5 100644 --- a/tests/validation/reference/Im2Col.cpp +++ b/tests/validation/reference/Im2Col.cpp @@ -55,11 +55,16 @@ void im2col_nchw(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D const int pad_val = is_data_type_quantized_asymmetric(src.data_type()) ? src.quantization_info().offset : 0; int dst_idx = 0; + // dst[dst_idx++] will write out of bounds if kernel_height == kernel_width == 1 because lasty will be the bottom padding row + // and this is not present in the dst buffer + const int lasty = src_height + (kernel_height > 1 ? pad_y : 0) - kernel_height; + const int lastx = src_width + (kernel_width > 1 ? pad_x : 0) - kernel_width; + for(int b = 0; b < batches; ++b) { - for(int y = -pad_y; y <= (src_height + pad_y - kernel_height); y += stride_y) + for(int y = -pad_y; y <= lasty; y += stride_y) { - for(int x = -pad_x; x <= (src_width + pad_x - kernel_width); x += stride_x) + for(int x = -pad_x; x <= lastx; x += stride_x) { for(int z = 0; z < src_depth; ++z) { @@ -97,11 +102,15 @@ void im2col_nhwc(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D const int batches = src.shape().total_size_upper(3); const int pad_val = is_data_type_quantized_asymmetric(src.data_type()) ? src.quantization_info().offset : 0; int dst_idx = 0; + + const int lasty = src_height + (kernel_height > 1 ? pad_y : 0) - kernel_height; + const int lastx = src_width + (kernel_width > 1 ? pad_x : 0) - kernel_width; + for(int b = 0; b < batches; ++b) { - for(int y = -pad_y; y <= (src_height + pad_y - kernel_height); y += stride_y) + for(int y = -pad_y; y <= lasty; y += stride_y) { - for(int x = -pad_x; x <= (src_width + pad_x - kernel_width); x += stride_x) + for(int x = -pad_x; x <= lastx; x += stride_x) { for(int z = 0; z < src_depth; ++z) { |