diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2018-08-20 18:03:27 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | fb62908bd8148bd347bd204e881156f8ebf7835d (patch) | |
tree | 78843eb937bb64f5e3439b8367f9cb6d7140d7b2 | |
parent | 66cbafb26261fbf091b799d1e5d0600fb08ee513 (diff) | |
download | ComputeLibrary-fb62908bd8148bd347bd204e881156f8ebf7835d.tar.gz |
COMPMID-1494 Optimise NEON im2col and weights reshape for NHWC
Change-Id: I99ebae61024a7bce9d17292a02c28626ae6c29d5
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144872
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
-rw-r--r-- | arm_compute/core/NEON/kernels/NEIm2ColKernel.h | 4 | ||||
-rw-r--r-- | src/core/NEON/kernels/NEIm2ColKernel.cpp | 219 | ||||
-rw-r--r-- | src/core/NEON/kernels/NEWeightsReshapeKernel.cpp | 27 | ||||
-rw-r--r-- | tests/validation/CL/Im2Col.cpp | 72 | ||||
-rw-r--r-- | tests/validation/NEON/Im2Col.cpp | 30 | ||||
-rw-r--r-- | tests/validation/fixtures/Im2ColFixture.h | 8 | ||||
-rw-r--r-- | tests/validation/reference/Im2Col.cpp | 66 | ||||
-rw-r--r-- | tests/validation/reference/Im2Col.h | 3 |
8 files changed, 218 insertions, 211 deletions
diff --git a/arm_compute/core/NEON/kernels/NEIm2ColKernel.h b/arm_compute/core/NEON/kernels/NEIm2ColKernel.h index ec89f0f713..f76521f770 100644 --- a/arm_compute/core/NEON/kernels/NEIm2ColKernel.h +++ b/arm_compute/core/NEON/kernels/NEIm2ColKernel.h @@ -109,11 +109,11 @@ public: void run(const Window &window, const ThreadInfo &info) override; private: - /** Template function to run the im2col + /** Template function to run im2col * * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). */ - template <typename T, bool has_pads> + template <typename T, bool has_pads, bool is_nchw> void run_im2col(const Window &window); /** Common signature for all the specialised im2col functions diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp index e5d31289a4..2c51eae468 100644 --- a/src/core/NEON/kernels/NEIm2ColKernel.cpp +++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp @@ -90,22 +90,22 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen } template <typename T, bool has_pads> -inline void linearize_volume(const uint8_t *const in_ptr, - T *out_ptr, - bool has_bias, - int top_left_x, - int top_left_y, - int kernel_width, - int kernel_height, - int kernel_depth, - int input_w, - int input_h, - int input_stride_x, - int input_stride_y, - int input_stride_z, - int pad_value, - int dilation_x, - int dilation_y) +inline void linearize_volume_nchw(const uint8_t *const in_ptr, + T *out_ptr, + bool has_bias, + int top_left_x, + int top_left_y, + int kernel_width, + int kernel_height, + int kernel_depth, + int input_w, + int input_h, + int input_stride_x, + int input_stride_y, + int input_stride_z, + int pad_value, + int dilation_x, + int dilation_y) { const int kernel_size2 = kernel_width * kernel_height; const int x_e = top_left_x + kernel_width * dilation_x; @@ -186,9 +186,62 @@ inline void linearize_volume(const uint8_t *const in_ptr, *out_ptr = static_cast<T>(1); } } -} // namespace template <typename T, bool has_pads> +inline void linearize_volume_nhwc(const uint8_t *const in_ptr, + T *out_ptr, + bool has_bias, + int start_x, + int start_y, + int kernel_width, + int kernel_height, + int input_w, + int input_h, + int input_c, + int input_stride_y, + int input_stride_z, + int pad_value, + int dilation_x, + int dilation_y) +{ + const int end_x = start_x + kernel_width * dilation_x; + const int end_y = start_y + kernel_height * dilation_y; + const int pad_quant = kernel_width * input_c; + + for(int y = start_y; y < end_y; y += dilation_y) + { + if(y < 0 || y >= input_h) + { + memset(out_ptr, pad_value, pad_quant * sizeof(T)); + out_ptr += pad_quant; + } + else + { + for(int x = start_x; x < end_x; x += dilation_x) + { + if(x < 0 || x >= input_w) + { + memset(out_ptr, pad_value, input_c * sizeof(T)); + out_ptr += input_c; + } + else + { + memcpy(out_ptr, reinterpret_cast<const T *>(in_ptr + (y * input_stride_z + x * input_stride_y)), input_c * sizeof(T)); + out_ptr += input_c; + } + } + } + } + + // Append 1 if the convolution layer has biases + if(has_bias) + { + *out_ptr = static_cast<T>(1); + } +} +} // namespace + +template <typename T, bool has_pads, bool is_nchw> void NEIm2ColKernel::run_im2col(const Window &window) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); @@ -199,25 +252,17 @@ void NEIm2ColKernel::run_im2col(const Window &window) const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - const int kernel_depth = _input->info()->dimension(channel_idx); const int input_w = _input->info()->dimension(width_idx); const int input_h = _input->info()->dimension(height_idx); - const int input_stride_x = _input->info()->strides_in_bytes()[width_idx]; - const int input_stride_y = _input->info()->strides_in_bytes()[height_idx]; - const int input_stride_z = _input->info()->strides_in_bytes()[channel_idx]; - const int offset = is_data_type_quantized(_input->info()->data_type()) ? _input->info()->quantization_info().offset : 0; - - int pad_left = 0; - int pad_top = 0; - int stride_x = 0; - int stride_y = 0; - pad_left = _conv_info.pad_left(); - pad_top = _conv_info.pad_top(); - std::tie(stride_x, stride_y) = _conv_info.stride(); - - // Setup input window - const int start_x = -pad_left; - const int start_y = -pad_top; + const int input_c = _input->info()->dimension(channel_idx); + const int input_stride_x = _input->info()->strides_in_bytes().x(); + const int input_stride_y = _input->info()->strides_in_bytes().y(); + const int input_stride_z = _input->info()->strides_in_bytes().z(); + const int pad_left = _conv_info.pad_left(); + const int pad_top = _conv_info.pad_top(); + const int stride_x = _conv_info.stride().first; + const int stride_y = _conv_info.stride().second; + const int pad_value = is_data_type_quantized(_input->info()->data_type()) ? _input->info()->quantization_info().offset : 0; Window window_in_out(window); // The first three dimensions of the input and output are increased by the inner loops @@ -231,30 +276,51 @@ void NEIm2ColKernel::run_im2col(const Window &window) execute_window_loop(window, [&](const Coordinates & id) { - const int top_left_x = id[width_idx] * stride_x + start_x; - const int top_left_y = id[height_idx] * stride_y + start_y; + const int start_w = id[width_idx] * stride_x - pad_left; + const int start_h = id[height_idx] * stride_y - pad_top; // Get pointers const uint8_t *const input_ptr = in.ptr(); auto output_ptr = reinterpret_cast<T *>(out.ptr() + (id[width_idx] + id[height_idx] * _convolved_dims.first) * _output->info()->strides_in_bytes().y()); // Linearize volume - linearize_volume<T, has_pads>(input_ptr, - output_ptr, - _has_bias, - top_left_x, - top_left_y, - static_cast<int>(_kernel_width), - static_cast<int>(_kernel_height), - kernel_depth, - input_w, - input_h, - input_stride_x, - input_stride_y, - input_stride_z, - offset, - _dilation.x(), - _dilation.y()); + if(is_nchw) + { + linearize_volume_nchw<T, has_pads>(input_ptr, + output_ptr, + _has_bias, + start_w, + start_h, + _kernel_width, + _kernel_height, + input_c, + input_w, + input_h, + input_stride_x, + input_stride_y, + input_stride_z, + pad_value, + _dilation.x(), + _dilation.y()); + } + else + { + linearize_volume_nhwc<T, has_pads>(input_ptr, + output_ptr, + _has_bias, + start_w, + start_h, + _kernel_width, + _kernel_height, + input_w, + input_h, + input_c, + input_stride_y, + input_stride_z, + pad_value, + _dilation.x(), + _dilation.y()); + } }, in, out); } @@ -286,22 +352,45 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size _conv_info, _dilation); _has_bias = has_bias; - switch(_input->info()->data_type()) + if(data_layout == DataLayout::NCHW) { - case DataType::F32: - _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float, false> : &NEIm2ColKernel::run_im2col<float, true>; - break; + switch(_input->info()->data_type()) + { + case DataType::F32: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float, false, true> : &NEIm2ColKernel::run_im2col<float, true, true>; + break; #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float16_t, false> : &NEIm2ColKernel::run_im2col<float16_t, true>; - break; + case DataType::F16: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float16_t, false, true> : &NEIm2ColKernel::run_im2col<float16_t, true, true>; + break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - case DataType::QASYMM8: - _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<qasymm8_t, false> : &NEIm2ColKernel::run_im2col<qasymm8_t, true>; - break; - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; + case DataType::QASYMM8: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<qasymm8_t, false, true> : &NEIm2ColKernel::run_im2col<qasymm8_t, true, true>; + break; + default: + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } + } + else + { + switch(_input->info()->data_type()) + { + case DataType::F32: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float, false, false> : &NEIm2ColKernel::run_im2col<float, true, false>; + break; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float16_t, false, false> : &NEIm2ColKernel::run_im2col<float16_t, true, false>; + break; +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + case DataType::QASYMM8: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<qasymm8_t, false, false> : &NEIm2ColKernel::run_im2col<qasymm8_t, true, false>; + break; + default: + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } } // Configure kernel window diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp index 2c9ad923aa..259f4fcb77 100644 --- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp +++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp @@ -34,16 +34,12 @@ using namespace arm_compute; namespace { -template <typename T, bool is_nhwc> +template <typename T> void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, const Window &window) { - DataLayout data_layout = input->info()->data_layout(); - const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); - const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - const unsigned int kernel_size_x = input->info()->dimension(idx_width); - const unsigned int kernel_size_y = input->info()->dimension(idx_height); - const unsigned int kernel_depth = input->info()->dimension(idx_channel); + const unsigned int kernel_size_x = input->info()->dimension(0); + const unsigned int kernel_size_y = input->info()->dimension(1); + const unsigned int kernel_depth = input->info()->dimension(2); const unsigned int input_stride_x = input->info()->strides_in_bytes().x(); const unsigned int input_stride_y = input->info()->strides_in_bytes().y(); const unsigned int input_stride_z = input->info()->strides_in_bytes().z(); @@ -71,13 +67,13 @@ void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, for(unsigned int i = 0; i < kernel_size_x; ++i) { *(reinterpret_cast<T *>(tmp_output_ptr)) = *(reinterpret_cast<const T *>(tmp_input_ptr)); - tmp_input_ptr += is_nhwc ? input_stride_y : input_stride_x; + tmp_input_ptr += input_stride_x; tmp_output_ptr += output_stride_y; } - curr_input_row_ptr += is_nhwc ? input_stride_z : input_stride_y; + curr_input_row_ptr += input_stride_y; tmp_input_ptr = curr_input_row_ptr; } - curr_input_depth_ptr += is_nhwc ? input_stride_x : input_stride_z; + curr_input_depth_ptr += input_stride_z; curr_input_row_ptr = curr_input_depth_ptr; tmp_input_ptr = curr_input_depth_ptr; } @@ -164,24 +160,21 @@ void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias _bias = bias; _output = output; - const DataLayout data_layout = input->info()->data_layout(); - const bool is_nhwc = data_layout == DataLayout::NHWC; - switch(_input->info()->element_size()) { case 4: { - _func = is_nhwc ? &weights_reshape<uint32_t, true> : &weights_reshape<uint32_t, false>; + _func = &weights_reshape<uint32_t>; break; } case 2: { - _func = is_nhwc ? &weights_reshape<uint16_t, true> : &weights_reshape<uint16_t, false>; + _func = &weights_reshape<uint16_t>; break; } case 1: { - _func = is_nhwc ? &weights_reshape<uint8_t, true> : &weights_reshape<uint8_t, false>; + _func = &weights_reshape<uint8_t>; break; } default: diff --git a/tests/validation/CL/Im2Col.cpp b/tests/validation/CL/Im2Col.cpp index cf7c79ad72..ebf2331e5e 100644 --- a/tests/validation/CL/Im2Col.cpp +++ b/tests/validation/CL/Im2Col.cpp @@ -94,17 +94,15 @@ template <typename T> using CLIm2ColFixture = Im2ColValidationFixture<CLTensor, CLAccessor, CLIm2Col, T, true>; TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -112,16 +110,14 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode: TEST_SUITE_END() TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -130,16 +126,14 @@ TEST_SUITE_END() TEST_SUITE_END() TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -148,19 +142,17 @@ TEST_SUITE_END() TEST_SUITE(Grouped) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", - DataType::F32)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", + DataType::F32)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", - DataType::F32)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", + DataType::F32)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -168,19 +160,17 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode: TEST_SUITE_END() TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", - DataType::F16)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", + DataType::F16)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", - DataType::F16)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", + DataType::F16)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -188,19 +178,17 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode:: TEST_SUITE_END() TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", - DataType::QASYMM8)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", + DataType::QASYMM8)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", - DataType::QASYMM8)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", + DataType::QASYMM8)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); diff --git a/tests/validation/NEON/Im2Col.cpp b/tests/validation/NEON/Im2Col.cpp index 0ea68bf49d..5a2b46a550 100644 --- a/tests/validation/NEON/Im2Col.cpp +++ b/tests/validation/NEON/Im2Col.cpp @@ -78,16 +78,14 @@ using NEIm2ColFixture = Im2ColValidationFixture<Tensor, Accessor, NEIm2Col, T, f TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); @@ -97,16 +95,14 @@ TEST_SUITE_END() #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); @@ -118,16 +114,14 @@ TEST_SUITE_END() TEST_SUITE_END() TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h index b5e83a9872..809bafd0b2 100644 --- a/tests/validation/fixtures/Im2ColFixture.h +++ b/tests/validation/fixtures/Im2ColFixture.h @@ -50,7 +50,7 @@ class Im2ColValidationFixture : public framework::Fixture public: template <typename...> void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout, - unsigned int num_groups, bool channels_first_output_nhwc) + unsigned int num_groups) { _kernel_dims = kernel_dims; _conv_info = conv_info; @@ -70,7 +70,7 @@ public: const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), batch_size_on_z && _num_groups == 1, _num_groups); _target = compute_target(input_shape, output_shape, data_type); - compute_reference(input_shape, output_shape, data_type, channels_first_output_nhwc); + compute_reference(input_shape, output_shape, data_type); } protected: @@ -109,7 +109,7 @@ protected: return dst; } - void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type, bool channels_first_output_nhwc) + void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) { // Create reference SimpleTensor<T> src{ input_shape, data_type, 1, _quant_info, _data_layout }; @@ -118,7 +118,7 @@ protected: // Fill reference fill(src); - reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias, _num_groups, channels_first_output_nhwc); + reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias, _num_groups); } TensorType _target{}; SimpleTensor<T> _reference{}; diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp index 0c41d88f3e..076b2aba07 100644 --- a/tests/validation/reference/Im2Col.cpp +++ b/tests/validation/reference/Im2Col.cpp @@ -93,52 +93,6 @@ template <typename T> void im2col_nhwc(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) { ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NHWC); - const int pad_x = conv_info.pad().first; - const int pad_y = conv_info.pad().second; - const int stride_x = conv_info.stride().first; - const int stride_y = conv_info.stride().second; - const int kernel_width = kernel_dims.width; - const int kernel_height = kernel_dims.height; - const int src_width = src.shape().y(); - const int src_height = src.shape().z(); - const int src_depth = src.shape().x(); - const int batches = src.shape().total_size_upper(3); - const int pad_val = is_data_type_quantized_asymmetric(src.data_type()) ? src.quantization_info().offset : 0; - int dst_idx = 0; - - const int lasty = src_height + (kernel_height > 1 ? pad_y : 0) - kernel_height; - const int lastx = src_width + (kernel_width > 1 ? pad_x : 0) - kernel_width; - - for(int b = 0; b < batches; ++b) - { - for(int y = -pad_y; y <= lasty; y += stride_y) - { - for(int x = -pad_x; x <= lastx; x += stride_x) - { - for(int z = 0; z < src_depth; ++z) - { - for(int patch_y = y; patch_y < (y + kernel_height); ++patch_y) - { - for(int patch_x = x; patch_x < (x + kernel_width); ++patch_x) - { - dst[dst_idx++] = tensor_elem_at(src, Coordinates(z, patch_x, patch_y, b), BorderMode::CONSTANT, static_cast<T>(pad_val)); - } - } - } - - if(has_bias) - { - dst[dst_idx++] = static_cast<T>(1); - } - } - } - } -} - -template <typename T> -void im2col_nhwc_channel_first(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) -{ - ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NHWC); const int stride_x = conv_info.stride().first; const int stride_y = conv_info.stride().second; const int kernel_width = kernel_dims.width; @@ -185,7 +139,7 @@ void im2col_nhwc_channel_first(const SimpleTensor<T> &src, SimpleTensor<T> &dst, } template <typename T> -void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups, bool channels_first_output_nhwc) +void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups) { switch(src.data_layout()) { @@ -196,14 +150,7 @@ void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kern } case DataLayout::NHWC: { - if(channels_first_output_nhwc) - { - im2col_nhwc_channel_first(src, dst, kernel_dims, conv_info, has_bias); - } - else - { - im2col_nhwc(src, dst, kernel_dims, conv_info, has_bias); - } + im2col_nhwc(src, dst, kernel_dims, conv_info, has_bias); break; } default: @@ -214,12 +161,9 @@ void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kern } } -template void im2col(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups, - bool channels_first_output_nhwc); -template void im2col(const SimpleTensor<half> &src, SimpleTensor<half> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups, - bool channels_first_output_nhwc); -template void im2col(const SimpleTensor<float> &src, SimpleTensor<float> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups, - bool channels_first_output_nhwc); +template void im2col(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups); +template void im2col(const SimpleTensor<half> &src, SimpleTensor<half> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups); +template void im2col(const SimpleTensor<float> &src, SimpleTensor<float> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/Im2Col.h b/tests/validation/reference/Im2Col.h index 84ee237453..f519d0e602 100644 --- a/tests/validation/reference/Im2Col.h +++ b/tests/validation/reference/Im2Col.h @@ -35,8 +35,7 @@ namespace validation namespace reference { template <typename T> -void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups, - bool channels_first_output_nhwc = false); +void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups); } // namespace reference } // namespace validation } // namespace test |