From 19ea419e7f14d02aeb208c2fbd5a4ac55f4cb101 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 19 Jun 2018 13:09:53 +0100 Subject: COMPMID-809: Add NHWC data format on CLGEMMConvolutionLayer. Change-Id: I50e4f5e7d47e21c300f754bee2c216863075b5cf Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/136191 Tested-by: Jenkins Reviewed-by: Giorgio Arena Reviewed-by: Gian Marco Iodice --- arm_compute/core/TensorShape.h | 14 + arm_compute/core/utils/misc/ShapeCalculator.h | 1 + .../runtime/CL/functions/CLGEMMConvolutionLayer.h | 27 +- src/core/CL/CLKernelLibrary.cpp | 3 +- src/core/CL/cl_kernels/col2im.cl | 12 +- src/core/CL/cl_kernels/convolution_layer.cl | 72 ++++- src/core/CL/cl_kernels/im2col.cl | 9 +- src/core/CL/kernels/CLCol2ImKernel.cpp | 16 +- src/core/CL/kernels/CLIm2ColKernel.cpp | 16 +- src/core/CL/kernels/CLWeightsReshapeKernel.cpp | 6 +- .../CL/functions/CLGEMMConvolutionLayer.cpp | 305 ++++++++++++++------- .../CL/functions/CLLocallyConnectedLayer.cpp | 10 +- tests/validation/CL/ConvolutionLayer.cpp | 14 +- tests/validation/CL/DilatedConvolutionLayer.cpp | 22 +- tests/validation/NEON/ConvolutionLayer.cpp | 6 +- tests/validation/NEON/DilatedConvolutionLayer.cpp | 14 +- .../validation/fixtures/ConvolutionLayerFixture.h | 7 +- 17 files changed, 380 insertions(+), 174 deletions(-) diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h index 0c3d9414e1..0340e1a644 100644 --- a/arm_compute/core/TensorShape.h +++ b/arm_compute/core/TensorShape.h @@ -136,6 +136,20 @@ public: // Make sure all empty dimensions are filled with 1 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); } + /** Shifts right the tensor shape increasing its dimensions + * + * @param[in] step Rotation step + */ + void shift_right(size_t step) + { + ARM_COMPUTE_ERROR_ON(step > TensorShape::num_max_dimensions - num_dimensions()); + + std::rotate(begin(), begin() + TensorShape::num_max_dimensions - step, end()); + _num_dimensions += step; + + // Correct number dimensions to ignore trailing dimensions of size 1 + apply_dimension_correction(); + } /** Return a copy with collapsed dimensions starting from a given point. * diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index f64cf9d6ae..115cbe688d 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -110,6 +110,7 @@ inline TensorShape compute_reductionB_shape(const ITensorInfo &a) inline TensorShape compute_col2im_shape(const ITensorInfo &input, std::pair convolved_dims) { TensorShape col2im_shape{ input.tensor_shape() }; + col2im_shape.shift_right(1); col2im_shape.set(0, convolved_dims.first); col2im_shape.set(1, convolved_dims.second); col2im_shape.set(2, input.tensor_shape()[0]); diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h index 3dde52989b..2c1f7a9d5e 100644 --- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h +++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h @@ -158,22 +158,24 @@ public: private: /** Configures the appropriate matrix multiply routine * - * @param input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. - * @param weights Weights tensor. Data type supported: Same as @p input. - * @param output Output tensor. Data types supported: Same as @p input, - * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. + * @param[in] weights Weights tensor. Data type supported: Same as @p input. + * @param[in, out] output Output tensor. Data types supported: Same as @p input, + * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) */ - void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output); + void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth = 1); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines * - * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. - * @param[in] weights Weights tensor. Data type supported: Same as @p input. - * @param[in] output Output tensor. Data types supported: Same as @p input, - * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. + * @param[in] weights Weights tensor. Data type supported: Same as @p input. + * @param[in] output Output tensor. Data types supported: Same as @p input, + * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) * * @return a status */ - static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output); + static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1); private: CLMemoryGroup _memory_group; @@ -192,9 +194,12 @@ private: CLTensor _gemm_output; CLTensor _tmp_output; + DataLayout _data_layout; + + bool _skip_im2col; bool _is_quantized; bool _is_activationlayer_enabled; bool _is_prepared; }; -} +} // namespace arm_compute #endif /* __ARM_COMPUTE_CLGEMMCONVOLUTIONLAYER_H__ */ diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp index 97e9e1057b..712a1179a6 100644 --- a/src/core/CL/CLKernelLibrary.cpp +++ b/src/core/CL/CLKernelLibrary.cpp @@ -329,7 +329,8 @@ const std::map CLKernelLibrary::_kernel_program_map = { "remap_nearest_neighbour", "remap.cl" }, { "remap_bilinear", "remap.cl" }, { "reshape_layer", "reshape_layer.cl" }, - { "reshape_to_columns", "convolution_layer.cl" }, + { "reshape_to_columns_nchw", "convolution_layer.cl" }, + { "reshape_to_columns_nhwc", "convolution_layer.cl" }, { "RGB888_to_IYUV_bt709", "color_convert.cl" }, { "RGB888_to_NV12_bt709", "color_convert.cl" }, { "RGB888_to_RGBA8888_bt709", "color_convert.cl" }, diff --git a/src/core/CL/cl_kernels/col2im.cl b/src/core/CL/cl_kernels/col2im.cl index 9b5a7b5b7e..6e491f33cf 100644 --- a/src/core/CL/cl_kernels/col2im.cl +++ b/src/core/CL/cl_kernels/col2im.cl @@ -52,8 +52,6 @@ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) - * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) @@ -66,11 +64,11 @@ * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes) */ __kernel void col2im( - TENSOR3D_DECLARATION(src), + IMAGE_DECLARATION(src), TENSOR3D_DECLARATION(dst), uint dst_stride_w) { - Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); + Image src = CONVERT_TO_IMAGE_STRUCT(src); VEC_DATA_TYPE(DATA_TYPE, 8) data = vload8(0, (__global DATA_TYPE *)src.ptr); @@ -113,8 +111,6 @@ __kernel void col2im( * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) - * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) @@ -127,11 +123,11 @@ __kernel void col2im( * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes) */ __kernel void col2im( - TENSOR3D_DECLARATION(src), + IMAGE_DECLARATION(src), TENSOR3D_DECLARATION(dst), uint dst_stride_w) { - Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); + Image src = CONVERT_TO_IMAGE_STRUCT(src); Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(dst); // Compute output offset diff --git a/src/core/CL/cl_kernels/convolution_layer.cl b/src/core/CL/cl_kernels/convolution_layer.cl index f8e0c27724..6a70b009c8 100644 --- a/src/core/CL/cl_kernels/convolution_layer.cl +++ b/src/core/CL/cl_kernels/convolution_layer.cl @@ -55,7 +55,7 @@ * @param[in] depth The depth of the input tensor * @param[in] total_filters Total number of filters. 4th dimension of the weights matrix */ -__kernel void reshape_to_columns( +__kernel void reshape_to_columns_nchw( TENSOR3D_DECLARATION(src), IMAGE_DECLARATION(dst), #ifdef HAS_BIAS @@ -97,4 +97,74 @@ __kernel void reshape_to_columns( } } } + +/** This kernel reshapes the tensor's low three dimensions to single column + * + * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short + * + * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32 + * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) + * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) + * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes) + * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor + * @param[out] dst_ptr Pointer to the destination tensor. Same as @p src_ptr + * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] bias_ptr Pointer to the bias tensor. Same as @p src_ptr + * @param[in] bias_stride_x Stride of the bias tensor in X dimension (in bytes) + * @param[in] bias_step_x bias_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] bias_offset_first_element_in_bytes The offset of the first element in the source tensor + * @param[in] depth The depth of the input tensor + * @param[in] width The width of the input tensor + * @param[in] height The height of the input tensor + * @param[in] total_filters Total number of filters. 4th dimension of the weights matrix + */ +__kernel void reshape_to_columns_nhwc( + TENSOR3D_DECLARATION(src), + IMAGE_DECLARATION(dst), +#ifdef HAS_BIAS + VECTOR_DECLARATION(bias), +#endif /* HAS_BIAS */ + uint depth, uint width, uint height, uint total_filters) +{ + Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); + bool is_last_thread = (get_global_id(0) == (get_global_size(0) - 1) && get_global_id(1) == (get_global_size(1) - 1) && get_global_id(2) == (get_global_size(2) - 1)); + + __global uchar *tmp_src_ptr = src.ptr; + __global uchar *tmp_dst_ptr = dst_ptr + dst_offset_first_element_in_bytes + get_global_id(1) * dst_stride_y + get_global_id(2) * width * dst_stride_y + get_global_id( + 0) * width * height * dst_stride_y; +#ifdef HAS_BIAS + __global uchar *tmp_bias_ptr = bias_ptr + bias_offset_first_element_in_bytes; +#endif /* HAS_BIAS */ + + if(is_last_thread) + { + for(uint i = 0; i < total_filters; ++i) + { + *((__global DATA_TYPE *)tmp_dst_ptr) = *((__global DATA_TYPE *)tmp_src_ptr); + +#ifdef HAS_BIAS + *((__global DATA_TYPE *)(tmp_dst_ptr + dst_stride_y)) = *((__global DATA_TYPE *)(tmp_bias_ptr)); + tmp_bias_ptr += bias_stride_x; +#endif /* HAS_BIAS */ + tmp_src_ptr += height * src_stride_z; + tmp_dst_ptr += dst_stride_x; + } + } + else + { + for(uint i = 0; i < total_filters; ++i) + { + *((__global DATA_TYPE *)tmp_dst_ptr) = *((__global DATA_TYPE *)tmp_src_ptr); + tmp_src_ptr += height * src_stride_z; + tmp_dst_ptr += dst_stride_x; + } + } +} #endif // defined(DATA_TYPE) \ No newline at end of file diff --git a/src/core/CL/cl_kernels/im2col.cl b/src/core/CL/cl_kernels/im2col.cl index c60c9a996c..6f25ad4b7a 100644 --- a/src/core/CL/cl_kernels/im2col.cl +++ b/src/core/CL/cl_kernels/im2col.cl @@ -136,6 +136,7 @@ __kernel void im2col1x1_stridex1_dchw( * @note The pad_left, pad_right, pad_top and pad_bottom must be passed at compile time using -DPAD_LEFT, -DPAD_RIGHT, -DPAD_TOP and -DPAD_BOTTOM: e.g. -DPAD_LEFT=1, -DPAD_RIGHT=2, -DPAD_TOP=3 and -DPAD_BOTTOM=2 * @note The zero value to store in case we load values out-of-bounds must be passed at compile time using -DPAD_VALUE: e.g. -DPAD_VALUE=0.0 * @note The stride along the X and Y directions must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1 and -DSTRIDE_Y=1 + * @note The dilation_x and dilation_y must be passed at compile time using -DDILATION_X and -DDILATION_Y: e.g. -DDILATION_X=1, -DDILATION_Y=1 * @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row. * * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QASYMM8/QS16/F16/F32 @@ -182,16 +183,18 @@ __kernel void im2col_generic_nhwc( for(int yk = 0; yk < KERNEL_HEIGHT; ++yk) { - const int y0 = yi + yk; + const int dilated_offset_y = yk * DILATION_Y; + const int y0 = yi + dilated_offset_y; if(y0 >= 0 && y0 < SRC_HEIGHT) { int xk; for(xk = 0; xk < KERNEL_WIDTH; xk++) { - const int x0 = xi + xk; + const int dilated_offset_x = xk * DILATION_X; + const int x0 = xi + dilated_offset_x; if(x0 >= 0 && x0 < SRC_WIDTH) { - *((__global DATA_TYPE *)output_ptr) = PTR_TO_VALUE(input_ptr + xk * src_stride_y + yk * src_stride_z, DATA_TYPE); + *((__global DATA_TYPE *)output_ptr) = PTR_TO_VALUE(input_ptr + dilated_offset_x * src_stride_y + dilated_offset_y * src_stride_z, DATA_TYPE); } else { diff --git a/src/core/CL/kernels/CLCol2ImKernel.cpp b/src/core/CL/kernels/CLCol2ImKernel.cpp index 4e444206f1..64e6a0b7d8 100644 --- a/src/core/CL/kernels/CLCol2ImKernel.cpp +++ b/src/core/CL/kernels/CLCol2ImKernel.cpp @@ -140,23 +140,25 @@ void CLCol2ImKernel::run(const Window &window, cl::CommandQueue &queue) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window); - // The collapse method rely on the assumption that the third dimension of input buffer is 1 - ARM_COMPUTE_ERROR_ON(window.z().end() != 1); + + Window out_window; + out_window.use_tensor_dimensions(_output->info()->tensor_shape()); Window collapsed_window = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); - Window slice = collapsed_window.first_slice_window_3D(); + Window slice = collapsed_window.first_slice_window_2D(); + Window slice_out = out_window.first_slice_window_3D(); // Set static kernel arguments - unsigned int idx = 2 * num_arguments_per_3D_tensor(); + unsigned int idx = num_arguments_per_2D_tensor() + num_arguments_per_3D_tensor(); _kernel.setArg(idx++, _output->info()->strides_in_bytes()[3]); do { // Set inputs unsigned int idx = 0; - add_3D_tensor_argument(idx, _input, slice); - add_3D_tensor_argument(idx, _output, slice); + add_2D_tensor_argument(idx, _input, slice); + add_3D_tensor_argument(idx, _output, slice_out); enqueue(queue, *this, slice, _lws_hint); } - while(collapsed_window.slide_window_slice_3D(slice)); + while(collapsed_window.slide_window_slice_2D(slice) && out_window.slide_window_slice_3D(slice_out)); } diff --git a/src/core/CL/kernels/CLIm2ColKernel.cpp b/src/core/CL/kernels/CLIm2ColKernel.cpp index 328b39681b..21deb9217c 100644 --- a/src/core/CL/kernels/CLIm2ColKernel.cpp +++ b/src/core/CL/kernels/CLIm2ColKernel.cpp @@ -143,7 +143,7 @@ CLIm2ColKernel::configure_window(const ICLTensor *input, ICLTensor *output, cons { case 1: // Optimized im2col1x1 if stride_x = 1 and conv_info.has_padding() = false - if(conv_info.stride().first == 1 && !conv_info.has_padding()) + if(conv_info.stride().first == 1 && !conv_info.has_padding() && data_layout == DataLayout::NCHW) { // Set hint for LWS _lws_hint = cl::NDRange(1, 1, 8); @@ -350,11 +350,14 @@ void CLIm2ColKernel::run_generic(const Window &window, cl::CommandQueue &queue) // Change the Z dimension's step back to 1 window_collapsed.set_dimension_step(Window::DimZ, 1); + Window window_output; + window_output.use_tensor_dimensions(_output->info()->tensor_shape()); + const Window first_slice_3d = window_collapsed.first_slice_window_3D(); Window slice = first_slice_3d; Window slice_in = first_slice_3d; - Window slice_out = first_slice_3d; + Window slice_out = window_output.first_slice_window_2D(); const bool out_dim_not_same_input_dim = _convolved_dims.first != _input->info()->dimension(width_idx) || _convolved_dims.second != _input->info()->dimension(height_idx); @@ -386,21 +389,16 @@ void CLIm2ColKernel::run_generic(const Window &window, cl::CommandQueue &queue) slice_in.set(Window::DimY, Window::Dimension(0, 0, 0)); slice_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); - // Setup output slice - slice_out.set(Window::DimX, Window::Dimension(0, _output->info()->dimension(0), _kernel_dims.area())); - slice_out.set(Window::DimY, Window::Dimension(0, _output->info()->dimension(1), _output->info()->dimension(1))); - slice_out.set(Window::DimZ, Window::Dimension(0, 1, 1)); - do { unsigned int idx = 0; add_3D_tensor_argument(idx, _input, slice_in); add_2D_tensor_argument(idx, _output, slice_out); _kernel.setArg(idx++, static_cast(_input->info()->strides_in_bytes()[3])); - _kernel.setArg(idx++, static_cast(_output->info()->strides_in_bytes()[3])); + _kernel.setArg(idx++, static_cast(_output->info()->strides_in_bytes()[2])); enqueue(queue, *this, slice, _lws_hint); } - while(window_collapsed.slide_window_slice_3D(slice) && window_collapsed.slide_window_slice_3D(slice_out) && window_collapsed.slide_window_slice_3D(slice_in)); + while(window_collapsed.slide_window_slice_3D(slice) && window_output.slide_window_slice_2D(slice_out) && window_collapsed.slide_window_slice_3D(slice_in)); } void CLIm2ColKernel::run_reduced(const Window &window, cl::CommandQueue &queue) diff --git a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp index c0a4517ad3..b012d58d59 100644 --- a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp +++ b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp @@ -85,7 +85,8 @@ void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor * (biases != nullptr) ? biases->info() : nullptr, output->info())); - const DataType data_type = input->info()->data_type(); + const DataType data_type = input->info()->data_type(); + const DataLayout data_layout = input->info()->data_layout(); _biases = biases; _output = output; @@ -98,7 +99,8 @@ void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor * build_opts.add_option_if(is_data_type_fixed_point(data_type), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position())); // Create kernel - _kernel = static_cast(CLKernelLibrary::get().create_kernel("reshape_to_columns", build_opts.options())); + std::string kernel_name = std::string("reshape_to_columns_") + lower_string(string_from_data_layout(data_layout)); + _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); // Set static arguments unsigned int idx = num_arguments_per_3D_tensor() + num_arguments_per_2D_tensor(); diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp index 82710b6461..ace3379618 100644 --- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp @@ -67,9 +67,10 @@ Status CLConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, co if(biases != nullptr) { + const int idx_kernels = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::BATCHES); ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(weights->data_type())); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases); - ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3)); + ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels)); ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); } @@ -91,11 +92,12 @@ void CLConvolutionLayerReshapeWeights::run() CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr memory_manager) : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(), - _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false) + _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _skip_im2col(false), _is_quantized(false), + _is_activationlayer_enabled(false), _is_prepared(false) { } -void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output) +void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights); ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info())); @@ -119,15 +121,15 @@ void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTenso else { // Configure matrix multiply function - _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/)); + _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth)); } } -Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output) +Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth) { const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type()); - const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */); + const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth); if(is_quantized) { // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() @@ -165,18 +167,32 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * dilation, act_info)); + const DataType data_type = input->info()->data_type(); + const DataLayout data_layout = input->info()->data_layout(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES); + + const unsigned int kernel_width = weights->info()->dimension(idx_width); + const unsigned int kernel_height = weights->info()->dimension(idx_height); + _is_prepared = weights_info.retain_internal_weights(); _original_weights = weights; _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); - - const DataType dt = input->info()->data_type(); + _data_layout = data_layout; + _skip_im2col = false; // Set the GPU target for im2col and col2im _im2col_kernel.set_target(CLScheduler::get().target()); _col2im_kernel.set_target(CLScheduler::get().target()); - const bool append_bias = (biases != nullptr) && (!_is_quantized); + bool is_nhwc = _data_layout == DataLayout::NHWC; + const ICLTensor *gemm_input_to_use = input; + ICLTensor *gemm_output_to_use = output; + ICLTensor *gemm_output_staged_to_use = output; + const bool append_bias = (biases != nullptr) && (!_is_quantized); const unsigned bias_element = (append_bias) ? 1 : 0; const ICLTensor *biases_to_use = (append_bias) ? biases : nullptr; @@ -188,14 +204,15 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * // Get convolved dimensions unsigned int conv_w = 0; unsigned int conv_h = 0; + std::tie(conv_w, conv_h) = scaled_dimensions(input->info()->dimension(idx_width), + input->info()->dimension(idx_height), + kernel_width, + kernel_height, + conv_info, + dilation); - const unsigned int kernel_width = weights->info()->dimension(0); - const unsigned int kernel_height = weights->info()->dimension(1); - std::tie(conv_w, conv_h) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_width, kernel_height, - conv_info, dilation); - - unsigned int mat_weights_cols = weights->info()->dimension(3); - unsigned int mat_weights_rows = weights->info()->dimension(0) * weights->info()->dimension(1) * weights->info()->dimension(2) + bias_element; + unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels); + unsigned int mat_weights_rows = weights->info()->dimension(idx_width) * weights->info()->dimension(idx_height) * weights->info()->dimension(idx_channel) + bias_element; // _weights_reshaped will be auto configured in the kernel. // Just append biases and do not transpose 1xW as it will be reshaped in CLGEMM @@ -204,38 +221,58 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * weights = &_weights_reshaped; // Create tensor to store im2col reshaped inputs - const unsigned int mat_input_cols = mat_weights_rows; - const unsigned int mat_input_rows = conv_w * conv_h; - TensorShape shape_im2col = input->info()->tensor_shape(); - shape_im2col.set(0, mat_input_cols); - shape_im2col.set(1, mat_input_rows); - shape_im2col.set(2, 1); - // FIXME: input->clone() doesn't work with subtensors for grouped convolutions. - TensorInfo im2col_reshaped_info(shape_im2col, 1, dt, input->info()->fixed_point_position()); - im2col_reshaped_info.set_quantization_info(input->info()->quantization_info()); - _im2col_output.allocator()->init(im2col_reshaped_info); - _memory_group.manage(&_im2col_output); + if(!_skip_im2col) + { + // Calculate im2col shape + TensorShape shape_im2col = input->info()->tensor_shape(); + if(shape_im2col.num_dimensions() >= 3) + { + shape_im2col.remove_dimension(2); + } + shape_im2col.set(0, mat_weights_rows); + shape_im2col.set(1, conv_w * conv_h); + + // FIXME: input->clone() doesn't work with subtensors for grouped convolutions. + TensorInfo im2col_reshaped_info(shape_im2col, 1, data_type, input->info()->fixed_point_position()); + im2col_reshaped_info.set_quantization_info(input->info()->quantization_info()); + _im2col_output.allocator()->init(im2col_reshaped_info); + _memory_group.manage(&_im2col_output); + + // Configure and tune im2col + _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation); + CLScheduler::get().tune_kernel_static(_im2col_kernel); + + // Update GEMM input + gemm_input_to_use = &_im2col_output; + } // Create GEMM output tensor - TensorShape shape_gemm = _im2col_output.info()->tensor_shape(); - shape_gemm.set(0, mat_weights_cols); - shape_gemm.set(1, mat_input_rows); - const DataType gemm_data_type = _is_quantized ? DataType::S32 : dt; - // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input. - // FIXME: input->clone() doesn't work with subtensors for grouped convolutions. - TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position()); - info_gemm.set_quantization_info(output->info()->quantization_info()); - _gemm_output.allocator()->init(info_gemm); - _memory_group.manage(&_gemm_output); - - // Configure and tune im2col - _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation); - CLScheduler::get().tune_kernel_static(_im2col_kernel); + if(!is_nhwc || _is_quantized) + { + // Calculate GEMM output shape + TensorShape shape_gemm = _im2col_output.info()->tensor_shape(); + shape_gemm.set(0, mat_weights_cols); + shape_gemm.set(1, conv_w * conv_h); + + // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input. + const DataType gemm_data_type = _is_quantized ? DataType::S32 : data_type; + // FIXME: input->clone() doesn't work with subtensors for grouped convolutions. + TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position()); + info_gemm.set_quantization_info(output->info()->quantization_info()); + _gemm_output.allocator()->init(info_gemm); + _memory_group.manage(&_gemm_output); + + // Update GEMM output + gemm_output_to_use = &_gemm_output; + } // Configure and tune GEMM - configure_mm(&_im2col_output, weights, &_gemm_output); + configure_mm(gemm_input_to_use, weights, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1); - _im2col_output.allocator()->allocate(); + if(!_skip_im2col) + { + _im2col_output.allocator()->allocate(); + } // Configure output stage for quantized case if(_is_quantized) @@ -245,20 +282,33 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale; int output_multiplier, output_shift; quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift); - _memory_group.manage(&_tmp_output); - _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset); + if(!is_nhwc) + { + _memory_group.manage(&_tmp_output); + gemm_output_staged_to_use = &_tmp_output; + } + _gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset); } - // Configure and tune Col2Im - _col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, std::make_pair(conv_w, conv_h)); - CLScheduler::get().tune_kernel_static(_col2im_kernel); - if(_is_quantized) + if(!is_nhwc) + { + // Configure and tune Col2Im + _col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, std::make_pair(conv_w, conv_h)); + CLScheduler::get().tune_kernel_static(_col2im_kernel); + } + + if(_is_quantized && !is_nhwc) { _tmp_output.allocator()->allocate(); } - _gemm_output.allocator()->allocate(); - ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(0) != conv_w) || (output->info()->dimension(1) != conv_h), "Output shape does not match the expected one"); + if(!is_nhwc || _is_quantized) + { + _gemm_output.allocator()->allocate(); + } + + ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h), + "Output shape does not match the expected one"); //Configure Activation Layer _is_activationlayer_enabled = act_info.enabled(); @@ -278,83 +328,128 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!"); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights); - ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) != input->dimension(2)); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QASYMM8 && input->data_layout() == DataLayout::NHWC, + "NHWC is unsupported for QASYMM8!"); + + const DataLayout data_layout = input->data_layout(); + const DataType data_type = input->data_type(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES); + + const unsigned int kernel_width = weights->dimension(idx_width); + const unsigned int kernel_height = weights->dimension(idx_height); + + TensorInfo im2col_reshaped_info, info_gemm, tmp_info, weights_reshaped_info; + const ITensorInfo *gemm_input_to_use = input; + const ITensorInfo *gemm_output_to_use = output; + const ITensorInfo *gemm_output_staged_to_use = output; + const ITensorInfo *weights_to_use = weights; + + const bool is_nhwc = data_layout == DataLayout::NHWC; + const bool skip_im2col = false; + const bool is_quantized = is_data_type_quantized_asymmetric(data_type); + const bool append_bias = (biases != nullptr) && (!is_quantized); + const unsigned bias_element = (append_bias) ? 1 : 0; + + ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel)); ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); + // Validate biases + if(biases != nullptr) + { + if(is_quantized) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases); + } + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases); + ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels)); + ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); + } + if(act_info.enabled()) { ARM_COMPUTE_ERROR_ON(act_info.b() > act_info.a()); } - const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type()); - const bool append_bias = (biases != nullptr) && (!is_quantized); - const unsigned bias_element = (append_bias) ? 1 : 0; - const DataType dt = input->data_type(); - // Get convolved dimensions unsigned int conv_w = 0; unsigned int conv_h = 0; - const unsigned int kernel_width = weights->dimension(0); - const unsigned int kernel_height = weights->dimension(1); + std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(idx_width), + input->dimension(idx_height), + kernel_width, + kernel_height, + conv_info, + dilation); - std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height, conv_info, dilation); - - unsigned int mat_weights_cols = weights->dimension(3); - unsigned int mat_weights_rows = weights->dimension(0) * weights->dimension(1) * weights->dimension(2) + bias_element; + unsigned int mat_weights_cols = weights->dimension(idx_kernels); + unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + bias_element; + // Output tensor auto inizialitation if not yet initialized ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, is_quantized ? nullptr : biases, nullptr)); + weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, data_type, weights->fixed_point_position()); + weights_to_use = &weights_reshaped_info; - // Create tensor info for im2col reshaped inputs - const unsigned int mat_input_cols = mat_weights_rows; - const unsigned int mat_input_rows = conv_w * conv_h; - TensorShape shape_im2col = input->tensor_shape(); - shape_im2col.set(0, mat_input_cols); - shape_im2col.set(1, mat_input_rows); - shape_im2col.set(2, 1); - TensorInfo im2col_reshaped_info(shape_im2col, 1, dt, input->fixed_point_position()); - im2col_reshaped_info.set_quantization_info(input->quantization_info()); - ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation)); + if(!skip_im2col) + { + // Create tensor info for im2col reshaped inputs + TensorShape shape_im2col = input->tensor_shape(); + if(input->tensor_shape().num_dimensions() >= 3) + { + shape_im2col.remove_dimension(2); + } + shape_im2col.set(0, mat_weights_rows); + shape_im2col.set(1, conv_w * conv_h); + im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type, input->fixed_point_position()); + im2col_reshaped_info.set_quantization_info(input->quantization_info()); + ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation)); + gemm_input_to_use = &im2col_reshaped_info; + } // Create GEMM output tensor - TensorShape shape_gemm = im2col_reshaped_info.tensor_shape(); - shape_gemm.set(0, mat_weights_cols); - shape_gemm.set(1, mat_input_rows); - const DataType gemm_data_type = is_quantized ? DataType::S32 : dt; - // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input. - TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->fixed_point_position()); - info_gemm.set_quantization_info(output->quantization_info()); - - ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(&im2col_reshaped_info, weights, &info_gemm)); - TensorInfo tmp_info(shape_gemm, 1, DataType::QASYMM8, input->fixed_point_position()); - tmp_info.set_quantization_info(output->quantization_info()); + if(!is_nhwc || is_quantized) + { + TensorShape shape_gemm = gemm_input_to_use->tensor_shape(); + shape_gemm.set(0, mat_weights_cols); + shape_gemm.set(1, conv_w * conv_h); + const DataType gemm_data_type = is_quantized ? DataType::S32 : data_type; + // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input. + info_gemm = TensorInfo(shape_gemm, 1, gemm_data_type, input->fixed_point_position()); + info_gemm.set_quantization_info(output->quantization_info()); + gemm_output_to_use = &info_gemm; + } + + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1)); if(is_quantized) { - float multiplier = input->quantization_info().scale * weights->quantization_info().scale / output->quantization_info().scale; + float multiplier = input->quantization_info().scale * weights_to_use->quantization_info().scale / output->quantization_info().scale; int output_multiplier, output_shift; quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift); + if(!is_nhwc) + { + tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8, input->fixed_point_position()); + tmp_info.set_quantization_info(output->quantization_info()); + gemm_output_staged_to_use = &tmp_info; + } // Validate output stage for quantized case - CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(&info_gemm, biases, &tmp_info, output->quantization_info().offset); + CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, output->quantization_info().offset); } // Validate Col2Im - ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? &tmp_info : &info_gemm, output, std::make_pair(conv_w, conv_h))); - - if(biases != nullptr) + if(!is_nhwc) { - if(is_quantized) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32); - } - else - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases); - } - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases); - ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3)); - ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, + output, + std::make_pair(conv_w, conv_h))); } //Validate Activation Layer @@ -373,7 +468,10 @@ void CLGEMMConvolutionLayer::run() _memory_group.acquire(); // Run im2col - CLScheduler::get().enqueue(_im2col_kernel); + if(!_skip_im2col) + { + CLScheduler::get().enqueue(_im2col_kernel); + } // Runs CLGEMM or CLGEMMLowpMatrixMultiplyCore functions if(_is_quantized) @@ -391,7 +489,10 @@ void CLGEMMConvolutionLayer::run() } // Reshape output matrix - CLScheduler::get().enqueue(_col2im_kernel, false); + if(_data_layout == DataLayout::NCHW) + { + CLScheduler::get().enqueue(_col2im_kernel, false); + } //Run Activation Layer if enabled if(_is_activationlayer_enabled) diff --git a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp index d15e5dfa3d..40bf032d69 100644 --- a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp @@ -48,7 +48,10 @@ void calculate_shapes(const ITensorInfo *input, const ITensorInfo *weights, cons // Get convolved dimensions unsigned int conv_w = 0; unsigned int conv_h = 0; - std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height, + std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), + input->dimension(1), + kernel_width, + kernel_height, conv_info); const size_t mat_weights_cols = weights->dimension(3); @@ -61,9 +64,12 @@ void calculate_shapes(const ITensorInfo *input, const ITensorInfo *weights, cons const size_t mat_input_rows = conv_w * conv_h; shape_im2col = input->tensor_shape(); + if(shape_im2col.num_dimensions() >= 3) + { + shape_im2col.remove_dimension(2); + } shape_im2col.set(0, mat_input_cols); shape_im2col.set(1, mat_input_rows); - shape_im2col.set(2, 1); shape_gemm = shape_im2col; shape_gemm.set(0, mat_weights_cols); diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp index 242c252015..7fd29f4d69 100644 --- a/tests/validation/CL/ConvolutionLayer.cpp +++ b/tests/validation/CL/ConvolutionLayer.cpp @@ -205,7 +205,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMConvolutionLayerFixture, framework: framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), ActivationFunctionsDataset)) { // Validate output @@ -216,7 +216,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMConvolutionLayerFixture, framework: framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), ActivationFunctionsDataset)) { // Validate output @@ -230,7 +230,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMConvolutionLayerFixture, framework framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), ActivationFunctionsDataset)) { // Validate output @@ -241,7 +241,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMConvolutionLayerFixture, framework framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), ActivationFunctionsDataset)) { // Validate output @@ -266,18 +266,20 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset)) { // Validate output validate(CLAccessor(_target), _reference, tolerance_qasymm8); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeConvolutionLayerDataset(), +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 0) })), QuantizedActivationFunctionsDataset)) { diff --git a/tests/validation/CL/DilatedConvolutionLayer.cpp b/tests/validation/CL/DilatedConvolutionLayer.cpp index 784e2001c1..4b22390b08 100644 --- a/tests/validation/CL/DilatedConvolutionLayer.cpp +++ b/tests/validation/CL/DilatedConvolutionLayer.cpp @@ -164,7 +164,7 @@ TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMDilatedConvolutionLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallDilatedConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("ActivationLayerInfo", ActivationLayerInfo()))) { // Validate output @@ -173,7 +173,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMDilatedConvolutionLayerFixture, fra FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMDilatedConvolutionLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeDilatedConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("ActivationLayerInfo", ActivationLayerInfo()))) { // Validate output @@ -185,7 +185,7 @@ TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMDilatedConvolutionLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallDilatedConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("ActivationLayerInfo", ActivationLayerInfo()))) { // Validate output @@ -194,7 +194,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMDilatedConvolutionLayerFixture, fr FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMDilatedConvolutionLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeDilatedConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("ActivationLayerInfo", ActivationLayerInfo()))) { // Validate output @@ -212,9 +212,10 @@ using CLGEMMDilatedConvolutionLayerQuantizedFixture = ConvolutionValidationQuant TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMDilatedConvolutionLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(datasets::SmallDilatedConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8)), + combine(combine(combine(combine(combine(datasets::SmallDilatedConvolutionLayerDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), framework::dataset::make("ActivationLayerInfo", { ActivationLayerInfo() }))) { @@ -222,9 +223,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMDilatedConvolutionLayerQuantizedFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(datasets::LargeDilatedConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8)), + combine(combine(combine(combine(combine(datasets::LargeDilatedConvolutionLayerDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 0) })), framework::dataset::make("ActivationLayerInfo", { ActivationLayerInfo() }))) { diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp index 747d8d2f62..94b38c2c81 100644 --- a/tests/validation/NEON/ConvolutionLayer.cpp +++ b/tests/validation/NEON/ConvolutionLayer.cpp @@ -259,18 +259,20 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ }); TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeConvolutionLayerDataset(), +FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset)) { diff --git a/tests/validation/NEON/DilatedConvolutionLayer.cpp b/tests/validation/NEON/DilatedConvolutionLayer.cpp index 2888a6535e..e703c67868 100644 --- a/tests/validation/NEON/DilatedConvolutionLayer.cpp +++ b/tests/validation/NEON/DilatedConvolutionLayer.cpp @@ -206,9 +206,10 @@ using NEGEMMDilatedConvolutionLayerQuantizedFixture = ConvolutionValidationQuant TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMDilatedConvolutionLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(datasets::SmallDilatedConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8)), + combine(combine(combine(combine(combine(datasets::SmallDilatedConvolutionLayerDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), framework::dataset::make("ActivationLayerInfo", ActivationLayerInfo()))) { @@ -216,9 +217,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMDilatedConvolutionLayerQuantizedFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(datasets::LargeDilatedConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8)), + combine(combine(combine(combine(combine(datasets::LargeDilatedConvolutionLayerDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), framework::dataset::make("ActivationLayerInfo", ActivationLayerInfo()))) { diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index 93de24d1bd..00ca0778f5 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -214,11 +214,10 @@ class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGeneri public: template void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, - QuantizationInfo quantization_info, ActivationLayerInfo act_info) + DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { - ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type, - DataLayout::NCHW, 0, - quantization_info, act_info); + ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, + data_type, data_layout, 0, quantization_info, act_info); } }; } // namespace validation -- cgit v1.2.1