From fb62908bd8148bd347bd204e881156f8ebf7835d Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Mon, 20 Aug 2018 18:03:27 +0100 Subject: COMPMID-1494 Optimise NEON im2col and weights reshape for NHWC Change-Id: I99ebae61024a7bce9d17292a02c28626ae6c29d5 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144872 Tested-by: Jenkins Reviewed-by: Gian Marco Iodice --- arm_compute/core/NEON/kernels/NEIm2ColKernel.h | 4 +- src/core/NEON/kernels/NEIm2ColKernel.cpp | 219 ++++++++++++++++------- src/core/NEON/kernels/NEWeightsReshapeKernel.cpp | 27 ++- tests/validation/CL/Im2Col.cpp | 72 ++++---- tests/validation/NEON/Im2Col.cpp | 30 ++-- tests/validation/fixtures/Im2ColFixture.h | 8 +- tests/validation/reference/Im2Col.cpp | 66 +------ tests/validation/reference/Im2Col.h | 3 +- 8 files changed, 218 insertions(+), 211 deletions(-) diff --git a/arm_compute/core/NEON/kernels/NEIm2ColKernel.h b/arm_compute/core/NEON/kernels/NEIm2ColKernel.h index ec89f0f713..f76521f770 100644 --- a/arm_compute/core/NEON/kernels/NEIm2ColKernel.h +++ b/arm_compute/core/NEON/kernels/NEIm2ColKernel.h @@ -109,11 +109,11 @@ public: void run(const Window &window, const ThreadInfo &info) override; private: - /** Template function to run the im2col + /** Template function to run im2col * * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). */ - template + template void run_im2col(const Window &window); /** Common signature for all the specialised im2col functions diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp index e5d31289a4..2c51eae468 100644 --- a/src/core/NEON/kernels/NEIm2ColKernel.cpp +++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp @@ -90,22 +90,22 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen } template -inline void linearize_volume(const uint8_t *const in_ptr, - T *out_ptr, - bool has_bias, - int top_left_x, - int top_left_y, - int kernel_width, - int kernel_height, - int kernel_depth, - int input_w, - int input_h, - int input_stride_x, - int input_stride_y, - int input_stride_z, - int pad_value, - int dilation_x, - int dilation_y) +inline void linearize_volume_nchw(const uint8_t *const in_ptr, + T *out_ptr, + bool has_bias, + int top_left_x, + int top_left_y, + int kernel_width, + int kernel_height, + int kernel_depth, + int input_w, + int input_h, + int input_stride_x, + int input_stride_y, + int input_stride_z, + int pad_value, + int dilation_x, + int dilation_y) { const int kernel_size2 = kernel_width * kernel_height; const int x_e = top_left_x + kernel_width * dilation_x; @@ -186,9 +186,62 @@ inline void linearize_volume(const uint8_t *const in_ptr, *out_ptr = static_cast(1); } } -} // namespace template +inline void linearize_volume_nhwc(const uint8_t *const in_ptr, + T *out_ptr, + bool has_bias, + int start_x, + int start_y, + int kernel_width, + int kernel_height, + int input_w, + int input_h, + int input_c, + int input_stride_y, + int input_stride_z, + int pad_value, + int dilation_x, + int dilation_y) +{ + const int end_x = start_x + kernel_width * dilation_x; + const int end_y = start_y + kernel_height * dilation_y; + const int pad_quant = kernel_width * input_c; + + for(int y = start_y; y < end_y; y += dilation_y) + { + if(y < 0 || y >= input_h) + { + memset(out_ptr, pad_value, pad_quant * sizeof(T)); + out_ptr += pad_quant; + } + else + { + for(int x = start_x; x < end_x; x += dilation_x) + { + if(x < 0 || x >= input_w) + { + memset(out_ptr, pad_value, input_c * sizeof(T)); + out_ptr += input_c; + } + else + { + memcpy(out_ptr, reinterpret_cast(in_ptr + (y * input_stride_z + x * input_stride_y)), input_c * sizeof(T)); + out_ptr += input_c; + } + } + } + } + + // Append 1 if the convolution layer has biases + if(has_bias) + { + *out_ptr = static_cast(1); + } +} +} // namespace + +template void NEIm2ColKernel::run_im2col(const Window &window) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); @@ -199,25 +252,17 @@ void NEIm2ColKernel::run_im2col(const Window &window) const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - const int kernel_depth = _input->info()->dimension(channel_idx); const int input_w = _input->info()->dimension(width_idx); const int input_h = _input->info()->dimension(height_idx); - const int input_stride_x = _input->info()->strides_in_bytes()[width_idx]; - const int input_stride_y = _input->info()->strides_in_bytes()[height_idx]; - const int input_stride_z = _input->info()->strides_in_bytes()[channel_idx]; - const int offset = is_data_type_quantized(_input->info()->data_type()) ? _input->info()->quantization_info().offset : 0; - - int pad_left = 0; - int pad_top = 0; - int stride_x = 0; - int stride_y = 0; - pad_left = _conv_info.pad_left(); - pad_top = _conv_info.pad_top(); - std::tie(stride_x, stride_y) = _conv_info.stride(); - - // Setup input window - const int start_x = -pad_left; - const int start_y = -pad_top; + const int input_c = _input->info()->dimension(channel_idx); + const int input_stride_x = _input->info()->strides_in_bytes().x(); + const int input_stride_y = _input->info()->strides_in_bytes().y(); + const int input_stride_z = _input->info()->strides_in_bytes().z(); + const int pad_left = _conv_info.pad_left(); + const int pad_top = _conv_info.pad_top(); + const int stride_x = _conv_info.stride().first; + const int stride_y = _conv_info.stride().second; + const int pad_value = is_data_type_quantized(_input->info()->data_type()) ? _input->info()->quantization_info().offset : 0; Window window_in_out(window); // The first three dimensions of the input and output are increased by the inner loops @@ -231,30 +276,51 @@ void NEIm2ColKernel::run_im2col(const Window &window) execute_window_loop(window, [&](const Coordinates & id) { - const int top_left_x = id[width_idx] * stride_x + start_x; - const int top_left_y = id[height_idx] * stride_y + start_y; + const int start_w = id[width_idx] * stride_x - pad_left; + const int start_h = id[height_idx] * stride_y - pad_top; // Get pointers const uint8_t *const input_ptr = in.ptr(); auto output_ptr = reinterpret_cast(out.ptr() + (id[width_idx] + id[height_idx] * _convolved_dims.first) * _output->info()->strides_in_bytes().y()); // Linearize volume - linearize_volume(input_ptr, - output_ptr, - _has_bias, - top_left_x, - top_left_y, - static_cast(_kernel_width), - static_cast(_kernel_height), - kernel_depth, - input_w, - input_h, - input_stride_x, - input_stride_y, - input_stride_z, - offset, - _dilation.x(), - _dilation.y()); + if(is_nchw) + { + linearize_volume_nchw(input_ptr, + output_ptr, + _has_bias, + start_w, + start_h, + _kernel_width, + _kernel_height, + input_c, + input_w, + input_h, + input_stride_x, + input_stride_y, + input_stride_z, + pad_value, + _dilation.x(), + _dilation.y()); + } + else + { + linearize_volume_nhwc(input_ptr, + output_ptr, + _has_bias, + start_w, + start_h, + _kernel_width, + _kernel_height, + input_w, + input_h, + input_c, + input_stride_y, + input_stride_z, + pad_value, + _dilation.x(), + _dilation.y()); + } }, in, out); } @@ -286,22 +352,45 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size _conv_info, _dilation); _has_bias = has_bias; - switch(_input->info()->data_type()) + if(data_layout == DataLayout::NCHW) { - case DataType::F32: - _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; - break; + switch(_input->info()->data_type()) + { + case DataType::F32: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; + break; #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; - break; + case DataType::F16: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; + break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - case DataType::QASYMM8: - _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; - break; - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; + case DataType::QASYMM8: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; + break; + default: + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } + } + else + { + switch(_input->info()->data_type()) + { + case DataType::F32: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; + break; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; + break; +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + case DataType::QASYMM8: + _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col : &NEIm2ColKernel::run_im2col; + break; + default: + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } } // Configure kernel window diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp index 2c9ad923aa..259f4fcb77 100644 --- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp +++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp @@ -34,16 +34,12 @@ using namespace arm_compute; namespace { -template +template void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, const Window &window) { - DataLayout data_layout = input->info()->data_layout(); - const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); - const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - const unsigned int kernel_size_x = input->info()->dimension(idx_width); - const unsigned int kernel_size_y = input->info()->dimension(idx_height); - const unsigned int kernel_depth = input->info()->dimension(idx_channel); + const unsigned int kernel_size_x = input->info()->dimension(0); + const unsigned int kernel_size_y = input->info()->dimension(1); + const unsigned int kernel_depth = input->info()->dimension(2); const unsigned int input_stride_x = input->info()->strides_in_bytes().x(); const unsigned int input_stride_y = input->info()->strides_in_bytes().y(); const unsigned int input_stride_z = input->info()->strides_in_bytes().z(); @@ -71,13 +67,13 @@ void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, for(unsigned int i = 0; i < kernel_size_x; ++i) { *(reinterpret_cast(tmp_output_ptr)) = *(reinterpret_cast(tmp_input_ptr)); - tmp_input_ptr += is_nhwc ? input_stride_y : input_stride_x; + tmp_input_ptr += input_stride_x; tmp_output_ptr += output_stride_y; } - curr_input_row_ptr += is_nhwc ? input_stride_z : input_stride_y; + curr_input_row_ptr += input_stride_y; tmp_input_ptr = curr_input_row_ptr; } - curr_input_depth_ptr += is_nhwc ? input_stride_x : input_stride_z; + curr_input_depth_ptr += input_stride_z; curr_input_row_ptr = curr_input_depth_ptr; tmp_input_ptr = curr_input_depth_ptr; } @@ -164,24 +160,21 @@ void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias _bias = bias; _output = output; - const DataLayout data_layout = input->info()->data_layout(); - const bool is_nhwc = data_layout == DataLayout::NHWC; - switch(_input->info()->element_size()) { case 4: { - _func = is_nhwc ? &weights_reshape : &weights_reshape; + _func = &weights_reshape; break; } case 2: { - _func = is_nhwc ? &weights_reshape : &weights_reshape; + _func = &weights_reshape; break; } case 1: { - _func = is_nhwc ? &weights_reshape : &weights_reshape; + _func = &weights_reshape; break; } default: diff --git a/tests/validation/CL/Im2Col.cpp b/tests/validation/CL/Im2Col.cpp index cf7c79ad72..ebf2331e5e 100644 --- a/tests/validation/CL/Im2Col.cpp +++ b/tests/validation/CL/Im2Col.cpp @@ -94,17 +94,15 @@ template using CLIm2ColFixture = Im2ColValidationFixture; TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -112,16 +110,14 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode: TEST_SUITE_END() TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -130,16 +126,14 @@ TEST_SUITE_END() TEST_SUITE_END() TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -148,19 +142,17 @@ TEST_SUITE_END() TEST_SUITE(Grouped) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", - DataType::F32)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", + DataType::F32)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", - DataType::F32)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", + DataType::F32)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -168,19 +160,17 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode: TEST_SUITE_END() TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", - DataType::F16)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", + DataType::F16)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", - DataType::F16)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", + DataType::F16)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); @@ -188,19 +178,17 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode:: TEST_SUITE_END() TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", - DataType::QASYMM8)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType", + DataType::QASYMM8)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", - DataType::QASYMM8)), - grouped_args), - framework::dataset::make("ChannelsFirstOutputNHWC", true))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType", + DataType::QASYMM8)), + grouped_args)) { // Validate output validate(CLAccessor(_target), _reference); diff --git a/tests/validation/NEON/Im2Col.cpp b/tests/validation/NEON/Im2Col.cpp index 0ea68bf49d..5a2b46a550 100644 --- a/tests/validation/NEON/Im2Col.cpp +++ b/tests/validation/NEON/Im2Col.cpp @@ -78,16 +78,14 @@ using NEIm2ColFixture = Im2ColValidationFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); @@ -97,16 +95,14 @@ TEST_SUITE_END() #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); @@ -118,16 +114,14 @@ TEST_SUITE_END() TEST_SUITE_END() TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args), - framework::dataset::make("ChannelsFirstOutputNHWC", false))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h index b5e83a9872..809bafd0b2 100644 --- a/tests/validation/fixtures/Im2ColFixture.h +++ b/tests/validation/fixtures/Im2ColFixture.h @@ -50,7 +50,7 @@ class Im2ColValidationFixture : public framework::Fixture public: template void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout, - unsigned int num_groups, bool channels_first_output_nhwc) + unsigned int num_groups) { _kernel_dims = kernel_dims; _conv_info = conv_info; @@ -70,7 +70,7 @@ public: const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), batch_size_on_z && _num_groups == 1, _num_groups); _target = compute_target(input_shape, output_shape, data_type); - compute_reference(input_shape, output_shape, data_type, channels_first_output_nhwc); + compute_reference(input_shape, output_shape, data_type); } protected: @@ -109,7 +109,7 @@ protected: return dst; } - void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type, bool channels_first_output_nhwc) + void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) { // Create reference SimpleTensor src{ input_shape, data_type, 1, _quant_info, _data_layout }; @@ -118,7 +118,7 @@ protected: // Fill reference fill(src); - reference::im2col(src, _reference, _kernel_dims, _conv_info, _has_bias, _num_groups, channels_first_output_nhwc); + reference::im2col(src, _reference, _kernel_dims, _conv_info, _has_bias, _num_groups); } TensorType _target{}; SimpleTensor _reference{}; diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp index 0c41d88f3e..076b2aba07 100644 --- a/tests/validation/reference/Im2Col.cpp +++ b/tests/validation/reference/Im2Col.cpp @@ -91,52 +91,6 @@ void im2col_nchw(const SimpleTensor &src, SimpleTensor &dst, const Size2D template void im2col_nhwc(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) -{ - ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NHWC); - const int pad_x = conv_info.pad().first; - const int pad_y = conv_info.pad().second; - const int stride_x = conv_info.stride().first; - const int stride_y = conv_info.stride().second; - const int kernel_width = kernel_dims.width; - const int kernel_height = kernel_dims.height; - const int src_width = src.shape().y(); - const int src_height = src.shape().z(); - const int src_depth = src.shape().x(); - const int batches = src.shape().total_size_upper(3); - const int pad_val = is_data_type_quantized_asymmetric(src.data_type()) ? src.quantization_info().offset : 0; - int dst_idx = 0; - - const int lasty = src_height + (kernel_height > 1 ? pad_y : 0) - kernel_height; - const int lastx = src_width + (kernel_width > 1 ? pad_x : 0) - kernel_width; - - for(int b = 0; b < batches; ++b) - { - for(int y = -pad_y; y <= lasty; y += stride_y) - { - for(int x = -pad_x; x <= lastx; x += stride_x) - { - for(int z = 0; z < src_depth; ++z) - { - for(int patch_y = y; patch_y < (y + kernel_height); ++patch_y) - { - for(int patch_x = x; patch_x < (x + kernel_width); ++patch_x) - { - dst[dst_idx++] = tensor_elem_at(src, Coordinates(z, patch_x, patch_y, b), BorderMode::CONSTANT, static_cast(pad_val)); - } - } - } - - if(has_bias) - { - dst[dst_idx++] = static_cast(1); - } - } - } - } -} - -template -void im2col_nhwc_channel_first(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) { ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NHWC); const int stride_x = conv_info.stride().first; @@ -185,7 +139,7 @@ void im2col_nhwc_channel_first(const SimpleTensor &src, SimpleTensor &dst, } template -void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups, bool channels_first_output_nhwc) +void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups) { switch(src.data_layout()) { @@ -196,14 +150,7 @@ void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kern } case DataLayout::NHWC: { - if(channels_first_output_nhwc) - { - im2col_nhwc_channel_first(src, dst, kernel_dims, conv_info, has_bias); - } - else - { - im2col_nhwc(src, dst, kernel_dims, conv_info, has_bias); - } + im2col_nhwc(src, dst, kernel_dims, conv_info, has_bias); break; } default: @@ -214,12 +161,9 @@ void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kern } } -template void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups, - bool channels_first_output_nhwc); -template void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups, - bool channels_first_output_nhwc); -template void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups, - bool channels_first_output_nhwc); +template void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups); +template void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups); +template void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/Im2Col.h b/tests/validation/reference/Im2Col.h index 84ee237453..f519d0e602 100644 --- a/tests/validation/reference/Im2Col.h +++ b/tests/validation/reference/Im2Col.h @@ -35,8 +35,7 @@ namespace validation namespace reference { template -void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups, - bool channels_first_output_nhwc = false); +void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups); } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1