diff options
author | Pablo Tello <pablo.tello@arm.com> | 2018-04-23 16:11:45 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:54 +0000 |
commit | 764b1af5f41749624c3900fc65c9dad3b1adf98c (patch) | |
tree | 9d8397de05d86eb40959d24058939c46c36f0d39 /tests/validation | |
parent | 3ca9786fe8ed00ad03963cae6a9eef7bb2fe630e (diff) | |
download | ComputeLibrary-764b1af5f41749624c3900fc65c9dad3b1adf98c.tar.gz |
COMPMID-1070: Rewrote im2col nhwc reference
The new reference computes directly the results without the need
of using temporary tensors and perform permutations which is problematic
for big 4k tensors as we get bad_alloc execptions.
Change-Id: I1fb0a495c4e9ca3356de76c9298832db8ace794a
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128683
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/fixtures/Im2ColFixture.h | 29 | ||||
-rw-r--r-- | tests/validation/reference/Im2Col.cpp | 74 | ||||
-rw-r--r-- | tests/validation/reference/Im2Col.h | 2 | ||||
-rw-r--r-- | tests/validation/reference/Utils.h | 6 |
4 files changed, 80 insertions, 31 deletions
diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h index f403aa9d21..b6e4cd0023 100644 --- a/tests/validation/fixtures/Im2ColFixture.h +++ b/tests/validation/fixtures/Im2ColFixture.h @@ -49,7 +49,7 @@ class Im2ColValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout) + void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout) { _kernel_dims = kernel_dims; _conv_info = conv_info; @@ -59,16 +59,17 @@ public: if(_data_layout == DataLayout::NHWC) { - permute(shape, PermutationVector(2U, 0U, 1U)); + permute(input_shape, PermutationVector(2U, 0U, 1U)); } - TensorShape output_shape; - TensorInfo input_info(shape, 1, data_type); + TensorInfo input_info(input_shape, 1, data_type); input_info.set_data_layout(_data_layout); - output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U)); - _target = compute_target(shape, output_shape, data_type); - _reference = compute_reference(shape, output_shape, data_type); + const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U)); + + _target = compute_target(input_shape, output_shape, data_type); + + compute_reference(input_shape, output_shape, data_type); } protected: @@ -78,10 +79,10 @@ protected: library->fill_tensor_uniform(tensor, 0); } - TensorType compute_target(const TensorShape &shape, const TensorShape &output_shape, DataType data_type) + TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) { // Create tensors - TensorType src = create_tensor<TensorType>(shape, data_type, 1, 0, _quant_info, _data_layout); + TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, _quant_info, _data_layout); TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, _quant_info, _data_layout); // Create and configure function @@ -107,17 +108,15 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape, const TensorShape &output_shape, DataType data_type) + void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) { // Create reference - SimpleTensor<T> src{ shape, data_type, 1, 0, _quant_info, _data_layout }; - + SimpleTensor<T> src{ input_shape, data_type, 1, 0, _quant_info, _data_layout }; + _reference = SimpleTensor<T>(output_shape, data_type, 1, 0, _quant_info, DataLayout::NCHW); // Fill reference fill(src); - - return reference::im2col<T>(src, output_shape, _kernel_dims, _conv_info, _has_bias); + reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias); } - TensorType _target{}; SimpleTensor<T> _reference{}; Size2D _kernel_dims{}; diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp index d309b7d5e6..5685b60026 100644 --- a/tests/validation/reference/Im2Col.cpp +++ b/tests/validation/reference/Im2Col.cpp @@ -40,6 +40,7 @@ namespace reference template <typename T> void im2col_nchw(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) { + ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NCHW); // Create reference const int pad_x = conv_info.pad().first; const int pad_y = conv_info.pad().second; @@ -81,26 +82,73 @@ void im2col_nchw(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D } template <typename T> -SimpleTensor<T> im2col(const SimpleTensor<T> &src, const TensorShape &dst_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) +void im2col_nhwc(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) { - SimpleTensor<T> dst{ dst_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; - - if(src.data_layout() == DataLayout::NHWC) + ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NHWC); + const int pad_x = conv_info.pad().first; + const int pad_y = conv_info.pad().second; + const int stride_x = conv_info.stride().first; + const int stride_y = conv_info.stride().second; + const int kernel_width = kernel_dims.width; + const int kernel_height = kernel_dims.height; + const int src_width = src.shape().y(); + const int src_height = src.shape().z(); + const int src_depth = src.shape().x(); + const int batches = src.shape().total_size_upper(3); + const int pad_val = is_data_type_quantized_asymmetric(src.data_type()) ? src.quantization_info().offset : 0; + int dst_idx = 0; + for(int b = 0; b < batches; ++b) { - SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); - im2col_nchw(src_nchw, dst, kernel_dims, conv_info, has_bias); + for(int y = -pad_y; y <= (src_height + pad_y - kernel_height); y += stride_y) + { + for(int x = -pad_x; x <= (src_width + pad_x - kernel_width); x += stride_x) + { + for(int z = 0; z < src_depth; ++z) + { + for(int patch_y = y; patch_y < (y + kernel_height); ++patch_y) + { + for(int patch_x = x; patch_x < (x + kernel_width); ++patch_x) + { + dst[dst_idx++] = tensor_elem_at(src, Coordinates(z, patch_x, patch_y, b), BorderMode::CONSTANT, static_cast<T>(pad_val)); + } + } + } + + if(has_bias) + { + dst[dst_idx++] = static_cast<T>(1); + } + } + } } - else +} + +template <typename T> +void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) +{ + switch(src.data_layout()) { - im2col_nchw(src, dst, kernel_dims, conv_info, has_bias); + case DataLayout::NCHW: + { + im2col_nchw(src, dst, kernel_dims, conv_info, has_bias); + break; + } + case DataLayout::NHWC: + { + im2col_nhwc(src, dst, kernel_dims, conv_info, has_bias); + break; + } + default: + { + ARM_COMPUTE_ERROR("Not supported."); + break; + } } - - return dst; } -template SimpleTensor<uint8_t> im2col(const SimpleTensor<uint8_t> &src, const TensorShape &output_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); -template SimpleTensor<half> im2col(const SimpleTensor<half> &src, const TensorShape &output_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); -template SimpleTensor<float> im2col(const SimpleTensor<float> &src, const TensorShape &output_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); +template void im2col(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); +template void im2col(const SimpleTensor<half> &src, SimpleTensor<half> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); +template void im2col(const SimpleTensor<float> &src, SimpleTensor<float> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/Im2Col.h b/tests/validation/reference/Im2Col.h index 4fe6ea9acf..5277171a2f 100644 --- a/tests/validation/reference/Im2Col.h +++ b/tests/validation/reference/Im2Col.h @@ -35,7 +35,7 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> im2col(const SimpleTensor<T> &src, const TensorShape &dst_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); +void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/Utils.h b/tests/validation/reference/Utils.h index 2aa77c6ff7..0e98bbe82b 100644 --- a/tests/validation/reference/Utils.h +++ b/tests/validation/reference/Utils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -62,11 +62,13 @@ T tensor_elem_at(const SimpleTensor<T> &src, Coordinates coord, BorderMode borde { const int x = coord.x(); const int y = coord.y(); + const int z = coord.z(); const int width = src.shape().x(); const int height = src.shape().y(); + const int depth = src.shape().z(); // If coordinates beyond range of tensor's width or height - if(x < 0 || y < 0 || x >= width || y >= height) + if(x < 0 || y < 0 || z < 0 || x >= width || y >= height || z >= depth) { if(border_mode == BorderMode::REPLICATE) { |