From 156fcf3f36f6168e47d65db167bba3af5037e3d9 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Fri, 9 Mar 2018 15:30:43 +0000 Subject: COMPMID-802 Add NHWC data format support for NEON im2col. Change-Id: I86e678179106a2b83d1c6a7cfe562df91b0f9eb2 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/124000 Tested-by: Jenkins Reviewed-by: Pablo Tello --- arm_compute/core/NEON/kernels/NEIm2ColKernel.h | 2 +- arm_compute/core/utils/misc/ShapeCalculator.h | 42 +++++-- arm_compute/runtime/NEON/functions/NEIm2Col.h | 22 +++- src/core/NEON/kernels/NEIm2ColKernel.cpp | 55 +++++---- src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 4 +- .../NEON/functions/NEFullyConnectedLayer.cpp | 4 +- src/runtime/NEON/functions/NEIm2Col.cpp | 25 ++-- tests/validation/NEON/Im2Col.cpp | 68 ++++++++++- tests/validation/fixtures/FlattenLayerFixture.h | 23 ++-- tests/validation/fixtures/Im2ColFixture.h | 132 +++++++++++++++++++++ tests/validation/reference/FlattenLayer.cpp | 16 +-- tests/validation/reference/FlattenLayer.h | 4 +- tests/validation/reference/Im2Col.cpp | 109 +++++++++++++++++ tests/validation/reference/Im2Col.h | 43 +++++++ tests/validation/reference/Permute.cpp | 6 +- 15 files changed, 482 insertions(+), 73 deletions(-) create mode 100644 tests/validation/fixtures/Im2ColFixture.h create mode 100644 tests/validation/reference/Im2Col.cpp create mode 100644 tests/validation/reference/Im2Col.h diff --git a/arm_compute/core/NEON/kernels/NEIm2ColKernel.h b/arm_compute/core/NEON/kernels/NEIm2ColKernel.h index ecfce2436d..5aa803f4fd 100644 --- a/arm_compute/core/NEON/kernels/NEIm2ColKernel.h +++ b/arm_compute/core/NEON/kernels/NEIm2ColKernel.h @@ -111,7 +111,7 @@ public: void run(const Window &window, const ThreadInfo &info) override; private: - /** Template function to run the im2col optimised for the fully connected layer case + /** Template function to run the im2col optimised for the fully connected and flatten layers case * * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). */ diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index c3d5b64a92..e174227302 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -107,13 +107,6 @@ inline TensorShape compute_reductionB_shape(const ITensorInfo &a) return shape_vector_sum_row; } -inline TensorShape compute_im2col_shape(const ITensorInfo &input) -{ - TensorShape shape_im2col{ input.tensor_shape() }; - shape_im2col.collapse(3); - - return shape_im2col; -} inline TensorShape compute_col2im_shape(const ITensorInfo &input, std::pair convolved_dims) { TensorShape col2im_shape{ input.tensor_shape() }; @@ -159,7 +152,25 @@ inline TensorShape compute_deconvolution_shape(const ITensorInfo &input, unsigne return scale_out_shape; } -inline TensorShape compute_im2col_shape(const ITensorInfo *input, const int num_input_dimensions = 3) +inline TensorShape compute_im2col_conv_shape(const ITensorInfo *input, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation) +{ + // The output shape will be the 2D shape used as input for GEMM [ out_channels * kernel_area, num_elems_per_out_channel ] + + TensorShape output_shape{ input->tensor_shape() }; + + const DataLayout data_layout = input->data_layout(); + const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + + std::pair out_dims = scaled_dimensions(output_shape[width_idx], output_shape[height_idx], kernel_dims.width, kernel_dims.height, conv_info, dilation); + output_shape.set(width_idx, (output_shape[channel_idx] * kernel_dims.area() + (has_bias ? 1 : 0))); + output_shape.set(height_idx, (out_dims.first * out_dims.second)); + output_shape.set(channel_idx, 1); + + return output_shape; +} +inline TensorShape compute_im2col_fc_shape(const ITensorInfo *input, const int num_input_dimensions = 3) { TensorShape output_shape{ input->tensor_shape() }; @@ -167,6 +178,21 @@ inline TensorShape compute_im2col_shape(const ITensorInfo *input, const int num_ return output_shape; } +inline TensorShape compute_im2col_flatten_shape(const ITensorInfo *input) +{ + // The output shape will be the flatten version of the input (i.e. [ width * height * channels, 1, 1, ... ] ). Used for FlattenLayer. + + ARM_COMPUTE_ERROR_ON(input->num_dimensions() < 3); + + TensorShape output_shape{ input->tensor_shape() }; + + const size_t flatten_shape = input->dimension(0) * input->dimension(1) * input->dimension(2); + output_shape.set(0, flatten_shape); + output_shape.remove_dimension(1); + output_shape.remove_dimension(1); + + return output_shape; +} inline TensorShape compute_interleave_custom_shape(const TensorShape &input, const int x_interleave, const int y_interleave) { TensorShape output_shape{ input }; diff --git a/arm_compute/runtime/NEON/functions/NEIm2Col.h b/arm_compute/runtime/NEON/functions/NEIm2Col.h index cf4999b5af..caa8a011f6 100644 --- a/arm_compute/runtime/NEON/functions/NEIm2Col.h +++ b/arm_compute/runtime/NEON/functions/NEIm2Col.h @@ -26,6 +26,7 @@ #include "arm_compute/runtime/NEON/INESimpleFunction.h" +#include "arm_compute/core/NEON/kernels/NEIm2ColKernel.h" #include "arm_compute/core/Size2D.h" #include "arm_compute/core/Types.h" @@ -34,9 +35,11 @@ namespace arm_compute class ITensor; /** Basic function to run @ref NEIm2ColKernel */ -class NEIm2Col : public INESimpleFunction +class NEIm2Col : public IFunction { public: + /** Default constructor */ + NEIm2Col(); /** Configure the im2col NEON kernel * * @param[in] input The input tensor to convert. 3 lower dimensions represent a single input [width, height, IFM], @@ -46,9 +49,10 @@ public: * @param[in] kernel_dims The kernel dimensions (width and height). * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. * @param[in] has_bias In case biases are provided expands the matrix with 1. - * @param[in] is_fully_connected Determines whether this kernel will be called by @ref NEFullyConnectedLayer in order to validate the arguments + * @param[in] is_fully_connected (Optional) Determines whether this function will be called by @ref NEFullyConnectedLayer in order to validate the arguments + * @param[in] is_flatten (Optional) Determines whether this function will be called by @ref NEFlattenLayer in order to validate the arguments */ - void configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected = false); + void configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected = false, bool is_flatten = false); /** Static function to check if given info will lead to a valid configuration of @ref NEIm2Col * * @param[in] input The input tensor to convert. 3 lower dimensions represent a single input [width, height, IFM], @@ -58,11 +62,19 @@ public: * @param[in] kernel_dims The kernel dimensions (width and height). * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. * @param[in] has_bias In case biases are provided expands the matrix with 1. - * @param[in] is_fully_connected Determines whether this kernel will be called by @ref NEFullyConnectedLayer in order to validate the arguments + * @param[in] is_fully_connected Determines whether this function will be called by @ref NEFullyConnectedLayer in order to validate the arguments + * @param[in] is_flatten Determines whether this function will be called by @ref NEFlattenLayer in order to validate the arguments * * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected); + static Status validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected, bool is_flatten); + + // Inherited methods overridden: + void run() override; + +private: + NEIm2ColKernel _kernel; + unsigned int _y_dim; }; } #endif /* __ARM_COMPUTE_NEIM2COL_H__ */ diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp index 348722c55d..5e165a641c 100644 --- a/src/core/NEON/kernels/NEIm2ColKernel.cpp +++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp @@ -53,27 +53,26 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::QASYMM8 && has_bias); ARM_COMPUTE_RETURN_ERROR_ON((dilation.x() < 1) || (dilation.y() < 1)); + TensorShape expected_output_shape; if(is_flatten) /* Called by FlattenLayer */ { - size_t flatten_shape = input->tensor_shape().x() * input->tensor_shape().y() * input->tensor_shape().z(); - ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != flatten_shape); + expected_output_shape = misc::shape_calculator::compute_im2col_flatten_shape(input); } else if(!is_fully_connected) /* Called by ConvolutionLayer */ { - std::pair out_dims = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_dims.width, kernel_dims.height, conv_info, dilation); - ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (input->dimension(2) * kernel_dims.area() + (has_bias ? 1 : 0))); - ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != (out_dims.first * out_dims.second)); - ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(2) != 1); + expected_output_shape = misc::shape_calculator::compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation); } else /* Called by FullyConnectedLayer */ { const int num_batch_dimensions = std::max(0, static_cast(output->tensor_shape().num_dimensions()) - 1); const int num_input_dimensions = input->tensor_shape().num_dimensions() - num_batch_dimensions; - TensorInfo expected_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_im2col_shape(input, num_input_dimensions)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output); + expected_output_shape = misc::shape_calculator::compute_im2col_fc_shape(input, num_input_dimensions); } + TensorInfo expected_output = output->clone()->set_tensor_shape(expected_output_shape); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output); + return Status{}; } @@ -194,12 +193,17 @@ void NEIm2ColKernel::run_generic(const Window &window) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - const int kernel_depth = _input->info()->dimension(2); - const int input_w = _input->info()->dimension(0); - const int input_h = _input->info()->dimension(1); - const int input_stride_x = _input->info()->strides_in_bytes().x(); - const int input_stride_y = _input->info()->strides_in_bytes().y(); - const int input_stride_z = _input->info()->strides_in_bytes().z(); + const DataLayout data_layout = _input->info()->data_layout(); + const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + + const int kernel_depth = _input->info()->dimension(channel_idx); + const int input_w = _input->info()->dimension(width_idx); + const int input_h = _input->info()->dimension(height_idx); + const int input_stride_x = _input->info()->strides_in_bytes()[width_idx]; + const int input_stride_y = _input->info()->strides_in_bytes()[height_idx]; + const int input_stride_z = _input->info()->strides_in_bytes()[channel_idx]; const int offset = is_data_type_quantized(_input->info()->data_type()) ? _input->info()->quantization_info().offset : 0; int pad_left = 0; @@ -222,9 +226,9 @@ void NEIm2ColKernel::run_generic(const Window &window) // Setup output window Window window_out(window); - window_out.set(Window::DimX, Window::Dimension(0, _output->info()->dimension(0), _output->info()->strides_in_bytes().y() / _output->info()->element_size())); - window_out.set(Window::DimY, Window::Dimension(window.y().start() * _convolved_dims.first, window.y().end() * _convolved_dims.first, _convolved_dims.first)); - window_out.set(Window::DimZ, Window::Dimension(0, 1, 1)); + window_out.set(width_idx, Window::Dimension(0, _output->info()->dimension(width_idx), _output->info()->strides_in_bytes()[width_idx + 1] / _output->info()->strides_in_bytes()[width_idx])); + window_out.set(height_idx, Window::Dimension(window[height_idx].start() * _convolved_dims.first, window[height_idx].end() * _convolved_dims.first, _convolved_dims.first)); + window_out.set(channel_idx, Window::Dimension(0, 1, 1)); // Create iterators Iterator in(_input, window_in); @@ -232,8 +236,8 @@ void NEIm2ColKernel::run_generic(const Window &window) execute_window_loop(window, [&](const Coordinates & id) { - const int top_left_x = id.x() * stride_x + start_x; - const int top_left_y = id.y() * stride_y + start_y; + const int top_left_x = id[width_idx] * stride_x + start_x; + const int top_left_y = id[height_idx] * stride_y + start_y; // Get pointers const uint8_t *const input_ptr = in.ptr(); @@ -327,13 +331,18 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size ARM_COMPUTE_UNUSED(is_fully_connected, is_flatten); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten, dilation)); + const DataLayout data_layout = input->info()->data_layout(); + const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + _input = input; _output = output; _conv_info = conv_info; _kernel_width = kernel_dims.width; _kernel_height = kernel_dims.height; _dilation = dilation; - _convolved_dims = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), + _convolved_dims = scaled_dimensions(input->info()->dimension(width_idx), input->info()->dimension(height_idx), _kernel_width, _kernel_height, _conv_info, _dilation); _has_bias = has_bias; @@ -402,9 +411,9 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size ARM_COMPUTE_ERROR("Data type not supported"); break; } - window.set(Window::DimX, Window::Dimension(0, _convolved_dims.first, 1)); - window.set(Window::DimY, Window::Dimension(0, _convolved_dims.second, 1)); - window.set(Window::DimZ, Window::Dimension(0, 1, 1)); + window.set(width_idx, Window::Dimension(0, _convolved_dims.first, 1)); + window.set(height_idx, Window::Dimension(0, _convolved_dims.second, 1)); + window.set(channel_idx, Window::Dimension(0, 1, 1)); } // The NEIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp index 676706fb17..5dd1f00fa6 100644 --- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp @@ -114,7 +114,7 @@ void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLT // If the fully connected layer is called after a convolution layer, the input tensor must be linearized // Initialize output tensor for im2col - TensorShape shape_im2col = compute_im2col_shape(input->info()); + TensorShape shape_im2col = compute_im2col_fc_shape(input->info()); _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col)); // Configure im2col kernel @@ -244,7 +244,7 @@ Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn bool is_quantized = is_data_type_quantized_asymmetric(input->data_type()); const GPUTarget gpu_target = CLScheduler::get().target(); - const ITensorInfo &im2col_input = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_shape(input))); + const ITensorInfo &im2col_input = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input))); const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights))); const ITensorInfo &gemmlowp_output = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32)); diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index b310ad35e3..958d081fd2 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -188,7 +188,7 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh if(_linearize_input) { - _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_shape(input->info(), num_input_dimensions))); + _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input->info(), num_input_dimensions))); // Configure im2col kernel _memory_group.manage(&_im2col_output); @@ -288,7 +288,7 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn if(linearize_input) { - im2col_output->set_tensor_shape(compute_im2col_shape(input, num_input_dimensions)); + im2col_output->set_tensor_shape(compute_im2col_fc_shape(input, num_input_dimensions)); ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, im2col_output.get(), Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, true)); diff --git a/src/runtime/NEON/functions/NEIm2Col.cpp b/src/runtime/NEON/functions/NEIm2Col.cpp index b962db9144..6b95cb0256 100644 --- a/src/runtime/NEON/functions/NEIm2Col.cpp +++ b/src/runtime/NEON/functions/NEIm2Col.cpp @@ -23,19 +23,30 @@ */ #include "arm_compute/runtime/NEON/functions/NEIm2Col.h" -#include "arm_compute/core/NEON/kernels/NEIm2ColKernel.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" #include "support/ToolchainSupport.h" using namespace arm_compute; -void NEIm2Col::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected) +NEIm2Col::NEIm2Col() + : _kernel(), _y_dim(1) { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(input, output, kernel_dims, conv_info, has_bias, is_fully_connected); - _kernel = std::move(k); } -Status NEIm2Col::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected) +void NEIm2Col::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected, bool is_flatten) { - return NEIm2ColKernel::validate(input, output, kernel_dims, conv_info, has_bias, is_fully_connected); + _y_dim = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT); + + _kernel.configure(input, output, kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten); +} + +Status NEIm2Col::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected, bool is_flatten) +{ + return NEIm2ColKernel::validate(input, output, kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten); +} + +void NEIm2Col::run() +{ + NEScheduler::get().schedule(&_kernel, _y_dim); } diff --git a/tests/validation/NEON/Im2Col.cpp b/tests/validation/NEON/Im2Col.cpp index 96dd6f86ab..ce00128afa 100644 --- a/tests/validation/NEON/Im2Col.cpp +++ b/tests/validation/NEON/Im2Col.cpp @@ -23,10 +23,13 @@ */ #include "arm_compute/core/Types.h" #include "arm_compute/runtime/NEON/functions/NEIm2Col.h" +#include "tests/NEON/Accessor.h" +#include "tests/datasets/ShapeDatasets.h" #include "tests/framework/Asserts.h" #include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" #include "tests/validation/Validation.h" +#include "tests/validation/fixtures/Im2ColFixture.h" namespace arm_compute { @@ -34,6 +37,12 @@ namespace test { namespace validation { +namespace +{ +const auto conv_args = combine(combine(combine(framework::dataset::make("KernelDims", { Size2D(3U, 3U), Size2D(5U, 5U) }), framework::dataset::make("PadStride", { PadStrideInfo(1U, 1U, 0U, 0U), PadStrideInfo(1U, 1U, 1U, 1U), PadStrideInfo(2U, 2U, 0U, 2U) })), + framework::dataset::make("QuantizationInfo", QuantizationInfo(0.5f, 10))), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); +} // namespace TEST_SUITE(NEON) TEST_SUITE(Im2Col) @@ -45,7 +54,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::QASYMM8), // Bias not supported with QASYMM8 TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::QASYMM8), // Mismatching shapes - TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::QASYMM8), + TensorInfo(TensorShape(10U, 12U, 2U, 2U), 1, DataType::QASYMM8), }), framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(3U, 4U, 10U, 2U), 1, DataType::F16), TensorInfo(TensorShape(3U, 4U, 10U, 2U), 1, DataType::F16), @@ -58,12 +67,67 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( framework::dataset::make("Expected", { false, false, false, false, false, true })), input_info, output_info, has_bias, expected) { - bool status = bool(NEIm2Col::validate(&input_info, &output_info, Size2D(3U, 3U), PadStrideInfo(), has_bias, false)); + bool status = bool(NEIm2Col::validate(&input_info, &output_info, Size2D(3U, 3U), PadStrideInfo(), has_bias, false, false)); ARM_COMPUTE_EXPECT(status == expected, framework::LogLevel::ERRORS); } // clang-format on // *INDENT-ON* +template +using NEIm2ColFixture = Im2ColValidationFixture; + +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), + conv_args)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), + conv_args)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() + +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +TEST_SUITE_END() + +TEST_SUITE(QASYMM8) +FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() + TEST_SUITE_END() TEST_SUITE_END() } // namespace validation diff --git a/tests/validation/fixtures/FlattenLayerFixture.h b/tests/validation/fixtures/FlattenLayerFixture.h index 3de0ba45ae..ef94ea83b0 100644 --- a/tests/validation/fixtures/FlattenLayerFixture.h +++ b/tests/validation/fixtures/FlattenLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,7 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/Tensor.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" @@ -43,6 +44,8 @@ namespace test { namespace validation { +using namespace arm_compute::misc::shape_calculator; + template class FlattenLayerValidationFixture : public framework::Fixture { @@ -51,8 +54,13 @@ public: void setup(TensorShape shape, DataType data_type) { _fractional_bits = is_data_type_fixed_point(data_type) ? 4 : 0; - _target = compute_target(shape, data_type); - _reference = compute_reference(shape, data_type); + + TensorShape shape_flatten; + TensorInfo input_info(shape, 1, data_type, _fractional_bits); + shape_flatten = compute_im2col_flatten_shape(&input_info); + + _target = compute_target(shape, shape_flatten, data_type); + _reference = compute_reference(shape, shape_flatten, data_type); ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(_target.info()->tensor_shape(), _reference.shape()); } @@ -73,11 +81,8 @@ protected: } } - TensorType compute_target(const TensorShape &shape, DataType data_type) + TensorType compute_target(const TensorShape &shape, const TensorShape &shape_flatten, DataType data_type) { - TensorShape shape_flatten(shape); - shape_flatten.collapse(3); - // Create tensors TensorType src = create_tensor(shape, data_type, 1, _fractional_bits); TensorType dst = create_tensor(shape_flatten, data_type, 1, _fractional_bits); @@ -105,7 +110,7 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &shape, DataType data_type) + SimpleTensor compute_reference(const TensorShape &shape, const TensorShape &shape_flatten, DataType data_type) { // Create reference SimpleTensor src{ shape, data_type, 1, _fractional_bits }; @@ -113,7 +118,7 @@ protected: // Fill reference fill(src); - return reference::flatten_layer(src); + return reference::flatten_layer(src, shape_flatten); } TensorType _target{}; diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h new file mode 100644 index 0000000000..f403aa9d21 --- /dev/null +++ b/tests/validation/fixtures/Im2ColFixture.h @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_TEST_IM2COL_FIXTURE +#define ARM_COMPUTE_TEST_IM2COL_FIXTURE + +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/runtime/Tensor.h" +#include "tests/AssetsLibrary.h" +#include "tests/Globals.h" +#include "tests/IAccessor.h" +#include "tests/framework/Asserts.h" +#include "tests/framework/Fixture.h" +#include "tests/validation/reference/Im2Col.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +using namespace arm_compute::misc::shape_calculator; + +template +class Im2ColValidationFixture : public framework::Fixture +{ +public: + template + void setup(TensorShape shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout) + { + _kernel_dims = kernel_dims; + _conv_info = conv_info; + _quant_info = quant_info; + _data_layout = data_layout; + _has_bias = data_type != DataType::QASYMM8; + + if(_data_layout == DataLayout::NHWC) + { + permute(shape, PermutationVector(2U, 0U, 1U)); + } + + TensorShape output_shape; + TensorInfo input_info(shape, 1, data_type); + input_info.set_data_layout(_data_layout); + output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U)); + + _target = compute_target(shape, output_shape, data_type); + _reference = compute_reference(shape, output_shape, data_type); + } + +protected: + template + void fill(U &&tensor) + { + library->fill_tensor_uniform(tensor, 0); + } + + TensorType compute_target(const TensorShape &shape, const TensorShape &output_shape, DataType data_type) + { + // Create tensors + TensorType src = create_tensor(shape, data_type, 1, 0, _quant_info, _data_layout); + TensorType dst = create_tensor(output_shape, data_type, 1, 0, _quant_info, _data_layout); + + // Create and configure function + FunctionType im2col_func; + im2col_func.configure(&src, &dst, _kernel_dims, _conv_info, _has_bias); + + ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Allocate tensors + src.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Fill tensors + fill(AccessorType(src)); + + // Compute function + im2col_func.run(); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &shape, const TensorShape &output_shape, DataType data_type) + { + // Create reference + SimpleTensor src{ shape, data_type, 1, 0, _quant_info, _data_layout }; + + // Fill reference + fill(src); + + return reference::im2col(src, output_shape, _kernel_dims, _conv_info, _has_bias); + } + + TensorType _target{}; + SimpleTensor _reference{}; + Size2D _kernel_dims{}; + PadStrideInfo _conv_info{}; + DataLayout _data_layout{}; + QuantizationInfo _quant_info{}; + bool _has_bias{}; +}; +} // namespace validation +} // namespace test +} // namespace arm_compute +#endif /* ARM_COMPUTE_TEST_IM2COL_FIXTURE */ diff --git a/tests/validation/reference/FlattenLayer.cpp b/tests/validation/reference/FlattenLayer.cpp index 611701d8cf..44f4d93178 100644 --- a/tests/validation/reference/FlattenLayer.cpp +++ b/tests/validation/reference/FlattenLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -34,12 +34,8 @@ namespace validation namespace reference { template -SimpleTensor flatten_layer(const SimpleTensor &src) +SimpleTensor flatten_layer(const SimpleTensor &src, const TensorShape &shape_flatten) { - TensorShape shape_flatten(src.shape()); - shape_flatten.set(0, src.shape()[0] * src.shape()[1] * src.shape()[2]); - shape_flatten.remove_dimension(1); - shape_flatten.remove_dimension(1); SimpleTensor dst(shape_flatten, src.data_type(), 1, src.fixed_point_position()); // Note: Since the reference implementation does not use padding bytes, we can copy directly the content of the source tensor @@ -48,10 +44,10 @@ SimpleTensor flatten_layer(const SimpleTensor &src) return dst; } -template SimpleTensor flatten_layer(const SimpleTensor &src); -template SimpleTensor flatten_layer(const SimpleTensor &src); -template SimpleTensor flatten_layer(const SimpleTensor &src); -template SimpleTensor flatten_layer(const SimpleTensor &src); +template SimpleTensor flatten_layer(const SimpleTensor &src, const TensorShape &shape_flatten); +template SimpleTensor flatten_layer(const SimpleTensor &src, const TensorShape &shape_flatten); +template SimpleTensor flatten_layer(const SimpleTensor &src, const TensorShape &shape_flatten); +template SimpleTensor flatten_layer(const SimpleTensor &src, const TensorShape &shape_flatten); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/FlattenLayer.h b/tests/validation/reference/FlattenLayer.h index b1286fe2bd..5ccd429e3b 100644 --- a/tests/validation/reference/FlattenLayer.h +++ b/tests/validation/reference/FlattenLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,7 +36,7 @@ namespace validation namespace reference { template -SimpleTensor flatten_layer(const SimpleTensor &src); +SimpleTensor flatten_layer(const SimpleTensor &src, const TensorShape &shape_flatten); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp new file mode 100644 index 0000000000..825f0a6ee1 --- /dev/null +++ b/tests/validation/reference/Im2Col.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "Im2Col.h" + +#include "Permute.h" + +#include "arm_compute/core/Types.h" +#include "tests/validation/Helpers.h" +#include "tests/validation/reference/Utils.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +namespace reference +{ +template +void im2col_nchw(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) +{ + // Create reference + const int pad_x = conv_info.pad().first; + const int pad_y = conv_info.pad().second; + const int stride_x = conv_info.stride().first; + const int stride_y = conv_info.stride().second; + const int kernel_width = kernel_dims.width; + const int kernel_height = kernel_dims.height; + const int src_width = src.shape().x(); + const int src_height = src.shape().y(); + const int src_depth = src.shape().z(); + const int batches = src.shape().total_size_upper(3); + const int pad_val = is_data_type_quantized_asymmetric(src.data_type()) ? src.quantization_info().offset : 0; + + int dst_idx = 0; + for(int b = 0; b < batches; ++b) + { + for(int y = -pad_y; y <= (src_height + pad_y - kernel_height); y += stride_y) + { + for(int x = -pad_x; x <= (src_width + pad_x - kernel_width); x += stride_x) + { + for(int z = 0; z < src_depth; ++z) + { + for(int patch_y = y; patch_y < (y + kernel_height); ++patch_y) + { + for(int patch_x = x; patch_x < (x + kernel_width); ++patch_x) + { + dst[dst_idx++] = tensor_elem_at(src, Coordinates(patch_x, patch_y, z, b), BorderMode::CONSTANT, static_cast(pad_val)); + } + } + } + + if(has_bias) + { + dst[dst_idx++] = static_cast(1); + } + } + } + } +} + +template +SimpleTensor im2col(const SimpleTensor &src, const TensorShape &dst_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) +{ + SimpleTensor dst{ dst_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; + + if(src.data_layout() == DataLayout::NHWC) + { + SimpleTensor src_nchw = reference::permute(src, PermutationVector(1U, 2U, 0U)); + SimpleTensor dst_nchw = reference::permute(dst, PermutationVector(1U, 2U, 0U)); + + im2col_nchw(src_nchw, dst_nchw, kernel_dims, conv_info, has_bias); + + return reference::permute(dst_nchw, PermutationVector(2U, 0U, 1U)); + } + + im2col_nchw(src, dst, kernel_dims, conv_info, has_bias); + + return dst; +} + +template SimpleTensor im2col(const SimpleTensor &src, const TensorShape &output_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); +template SimpleTensor im2col(const SimpleTensor &src, const TensorShape &output_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); +template SimpleTensor im2col(const SimpleTensor &src, const TensorShape &output_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); +} // namespace reference +} // namespace validation +} // namespace test +} // namespace arm_compute diff --git a/tests/validation/reference/Im2Col.h b/tests/validation/reference/Im2Col.h new file mode 100644 index 0000000000..4fe6ea9acf --- /dev/null +++ b/tests/validation/reference/Im2Col.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_TEST_IM2COL_H__ +#define __ARM_COMPUTE_TEST_IM2COL_H__ + +#include "tests/SimpleTensor.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +namespace reference +{ +template +SimpleTensor im2col(const SimpleTensor &src, const TensorShape &dst_shape, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias); +} // namespace reference +} // namespace validation +} // namespace test +} // namespace arm_compute +#endif /* __ARM_COMPUTE_TEST_IM2COL_H__ */ diff --git a/tests/validation/reference/Permute.cpp b/tests/validation/reference/Permute.cpp index 4a12ca6959..db347e51f5 100644 --- a/tests/validation/reference/Permute.cpp +++ b/tests/validation/reference/Permute.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -42,7 +42,7 @@ SimpleTensor permute(const SimpleTensor &src, PermutationVector perm) permute(dst_shape, perm); // Create reference - SimpleTensor dst{ dst_shape, src.data_type() }; + SimpleTensor dst{ dst_shape, src.data_type(), src.num_channels(), src.fixed_point_position(), src.quantization_info() }; // Compute reference for(int i = 0; i < src.num_elements(); ++i) @@ -60,6 +60,8 @@ SimpleTensor permute(const SimpleTensor &src, PermutationVector perm) template SimpleTensor permute(const SimpleTensor &src, PermutationVector perm); template SimpleTensor permute(const SimpleTensor &src, PermutationVector perm); template SimpleTensor permute(const SimpleTensor &src, PermutationVector perm); +template SimpleTensor permute(const SimpleTensor &src, PermutationVector perm); +template SimpleTensor permute(const SimpleTensor &src, PermutationVector perm); } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1