From d7295b7079f6b9126596cea998146ca9c6e87706 Mon Sep 17 00:00:00 2001 From: Dmitry Savenko Date: Mon, 20 Nov 2017 22:00:08 +0700 Subject: COMPMID-661: Add QASYMM8 support (and basic tests) to CLDepthwiseConvolution3x3 kernel (#28) Change-Id: I51bebe74e3814c1245812ad575fe7854d460674f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/109864 Reviewed-by: Anthony Barbier Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com --- .../CL/kernels/CLDepthwiseConvolution3x3Kernel.h | 2 +- .../runtime/CL/functions/CLDepthwiseConvolution.h | 2 +- src/core/CL/CLKernelLibrary.cpp | 5 + .../cl_kernels/depthwise_convolution_quantized.cl | 258 +++++++++++++++++++++ src/core/CL/cl_kernels/helpers_asymm.h | 2 + .../CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp | 45 +++- .../CL/functions/CLDepthwiseConvolution.cpp | 6 +- tests/validation/CL/DepthwiseConvolution.cpp | 41 +++- tests/validation/CPP/ConvolutionLayer.cpp | 4 +- tests/validation/CPP/DepthwiseConvolution.cpp | 83 ++++++- tests/validation/CPP/DepthwiseConvolution.h | 4 +- tests/validation/NEON/DepthwiseConvolution.cpp | 8 +- .../fixtures/DepthwiseConvolutionFixture.h | 89 +++++-- .../fixtures/DirectConvolutionLayerFixture.h | 6 + 14 files changed, 508 insertions(+), 47 deletions(-) create mode 100644 src/core/CL/cl_kernels/depthwise_convolution_quantized.cl diff --git a/arm_compute/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.h b/arm_compute/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.h index 1eca33ffb7..f9689a4329 100644 --- a/arm_compute/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.h +++ b/arm_compute/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.h @@ -47,7 +47,7 @@ public: CLDepthwiseConvolution3x3Kernel &operator=(CLDepthwiseConvolution3x3Kernel &&) = default; /** Initialize the function's source, destination, conv and border_size. * - * @param[in] input Source tensor. DataType supported: F32. + * @param[in] input Source tensor. DataType supported: QASYMM8/F32. * @param[in] weights Weights tensor. A 3D tensor with dimensions [3, 3, IFM]. Data type supported: Same as @p input. * @param[in] biases (Optional) Biases tensor. A 1D tensor with dimensions [IFM]. Must be nullptr if not needed. * Data type supported: Same as @p input. diff --git a/arm_compute/runtime/CL/functions/CLDepthwiseConvolution.h b/arm_compute/runtime/CL/functions/CLDepthwiseConvolution.h index 7c35c2a4af..40eb8523fb 100644 --- a/arm_compute/runtime/CL/functions/CLDepthwiseConvolution.h +++ b/arm_compute/runtime/CL/functions/CLDepthwiseConvolution.h @@ -51,7 +51,7 @@ public: CLDepthwiseConvolution3x3(); /** Initialize the function's source, destination, conv and border_size. * - * @param[in, out] input Source tensor. Data type supported: F32. (Written to only for border filling). + * @param[in, out] input Source tensor. Data type supported: QASYMM8/F32. (Written to only for border filling). * @param[in] weights Weights tensor. A 3D tensor with shape [3, 3, IFM]. Data type supported: Same as @p input. * @param[in] biases (Optional) Biases tensor. A 1D tensor with shape [IFM]. Must be nullptr if not needed. * Data type supported: Same as @p input. 
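
For orientation before the new OpenCL kernel and CPP reference below: a scalar model of what the QASYMM8 path added by this patch computes per output element. This is an illustrative sketch only (round_div_by_pow2, fixed_point_mul and depthwise_3x3_quantized are made-up names, not symbols from the patch); it mirrors the int32 accumulation with negated zero-point offsets, the optional int32 bias, and the fixed-point requantization performed by ASYMM_MULT and ASYMM_ROUNDING_DIVIDE_BY_POW2, with saturation handling omitted.

#include <algorithm>
#include <cstdint>

// Rounding right shift, mirroring ASYMM_ROUNDING_DIVIDE_BY_POW2 in helpers_asymm.h.
inline int32_t round_div_by_pow2(int32_t x, int exponent)
{
    const int32_t mask      = (1 << exponent) - 1;
    const int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
    return (x >> exponent) + (((x & mask) > threshold) ? 1 : 0);
}

// High half of the rounding doubling multiply with a Q0.31 multiplier
// (the ASYMM_MULT helper, minus the overflow check).
inline int32_t fixed_point_mul(int32_t a, int32_t b)
{
    const int64_t ab    = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    const int64_t nudge = (ab >= 0) ? (1ll << 30) : (1 - (1ll << 30));
    return static_cast<int32_t>((ab + nudge) / (1ll << 31));
}

// One output element of the quantized 3x3 depthwise convolution. input_offset and
// weight_offset are the *negated* zero points, exactly as configure() passes them.
inline uint8_t depthwise_3x3_quantized(const uint8_t in[3][3], const uint8_t w[3][3],
                                       int32_t bias, int input_offset, int weight_offset,
                                       int output_offset, int32_t output_multiplier, int output_shift)
{
    int32_t acc = bias;
    for(int y = 0; y < 3; ++y)
    {
        for(int x = 0; x < 3; ++x)
        {
            acc += (in[y][x] + input_offset) * (w[y][x] + weight_offset);
        }
    }
    acc = round_div_by_pow2(fixed_point_mul(acc, output_multiplier), output_shift);
    acc += output_offset;
    return static_cast<uint8_t>(std::max(0, std::min(255, acc)));
}
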
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp index 94cc02a705..9a2bb81708 100644 --- a/src/core/CL/CLKernelLibrary.cpp +++ b/src/core/CL/CLKernelLibrary.cpp @@ -187,6 +187,7 @@ const std::map CLKernelLibrary::_kernel_program_map = { "copy_planes_3p", "channel_combine.cl" }, { "copy_to_keypoint", "fast_corners.cl" }, { "depthwise_convolution_3x3", "depthwise_convolution.cl" }, + { "depthwise_convolution_3x3_quantized", "depthwise_convolution_quantized.cl" }, { "depthwise_im2col", "depthwise_convolution.cl" }, { "depthwise_vector_to_tensor", "depthwise_convolution.cl" }, { "depthwise_weights_reshape", "depthwise_convolution.cl" }, @@ -417,6 +418,10 @@ const std::map CLKernelLibrary::_program_source_map = { "depthwise_convolution.cl", #include "./cl_kernels/depthwise_convolution.clembed" + }, + { + "depthwise_convolution_quantized.cl", +#include "./cl_kernels/depthwise_convolution_quantized.clembed" }, { "dequantization_layer.cl", diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl new file mode 100644 index 0000000000..19a509bd0a --- /dev/null +++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl @@ -0,0 +1,258 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "helpers_asymm.h" + +#if defined(CONV_STRIDE_X) + +#if CONV_STRIDE_X == 1 +#define convolution1x3 convolution1x3_stride_1 +#elif CONV_STRIDE_X == 2 +#define convolution1x3 convolution1x3_stride_2 +#elif CONV_STRIDE_X == 3 +#define convolution1x3 convolution1x3_stride_3 +#else /* CONV_STRIDE_X */ +#error "Stride not supported" +#endif /* CONV_STRIDE_X */ + +/** Compute a 1D horizontal convolution of size 3 and stride 1 for uchar type. + * + * @param[in] left_pixel Pointer to the left pixel. + * @param[in] left_coeff Weight of the left pixel + * @param[in] middle_coeff Weight of the middle pixel + * @param[in] right_coeff Weight of the right pixel + * @param[in] input_offset Quantized offset of zero point of the input tensor data range + * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range + * + * @return a int2 containing 2 convoluted values. 
+ */ +inline int2 convolution1x3_stride_1(__global const uchar *left_pixel, + const int left_coeff, + const int middle_coeff, + const int right_coeff, + const int input_offset, + const int weight_offset) +{ + int4 temp = CONVERT(vload4(0, left_pixel), int4); + + int2 left = CONVERT(temp.s01, int2); + int2 middle = CONVERT(temp.s12, int2); + int2 right = CONVERT(temp.s23, int2); + + return (left + input_offset) * (int2)(left_coeff + weight_offset) + (middle + input_offset) * (int2)(middle_coeff + weight_offset) + (right + input_offset) * (int2)(right_coeff + weight_offset); +} + +/** Compute a 1D horizontal convolution of size 3 and stride 2 for uchar type. + * + * @param[in] left_pixel Pointer to the left pixel. + * @param[in] left_coeff Weight of the left pixel + * @param[in] middle_coeff Weight of the middle pixel + * @param[in] right_coeff Weight of the right pixel + * @param[in] input_offset Quantized offset of zero point of the input tensor data range + * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range + * + * @return a int2 containing 2 convoluted values. + */ +inline int2 convolution1x3_stride_2(__global const uchar *left_pixel, + const int left_coeff, + const int middle_coeff, + const int right_coeff, + const int input_offset, + const int weight_offset) +{ + int4 temp0 = CONVERT(vload4(0, left_pixel), int4); + int temp1 = CONVERT(*(left_pixel + 4 * sizeof(uchar)), int); + + int2 left = CONVERT(temp0.s02, int2); + int2 middle = CONVERT(temp0.s13, int2); + int2 right = CONVERT((int2)(temp0.s2, temp1), int2); + + return (left + input_offset) * (int2)(left_coeff + weight_offset) + (middle + input_offset) * (int2)(middle_coeff + weight_offset) + (right + input_offset) * (int2)(right_coeff + weight_offset); +} + +/** Compute a 1D horizontal convolution of size 3 and stride 3 for uchar type. + * + * @param[in] left_pixel Pointer to the left pixel. + * @param[in] left_coeff Weight of the left pixel + * @param[in] middle_coeff Weight of the middle pixel + * @param[in] right_coeff Weight of the right pixel + * @param[in] input_offset Quantized offset of zero point of the input tensor data range + * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range + * + * @return a int2 containing 2 convoluted values. + */ +inline int2 convolution1x3_stride_3(__global const uchar *left_pixel, + const int left_coeff, + const int middle_coeff, + const int right_coeff, + const int input_offset, + const int weight_offset) +{ + int4 temp0 = CONVERT(vload4(0, left_pixel), int4); + int2 temp1 = CONVERT(vload2(0, (left_pixel + 4 * sizeof(uchar))), int2); + + int2 left = CONVERT(temp0.s03, int2); + int2 middle = CONVERT((int2)(temp0.s1, temp1.s0), int2); + int2 right = CONVERT((int2)(temp0.s2, temp1.s1), int2); + + return (left + input_offset) * (int2)(left_coeff + weight_offset) + (middle + input_offset) * (int2)(middle_coeff + weight_offset) + (right + input_offset) * (int2)(right_coeff + weight_offset); +} + +/** Apply a 3x3 convolution matrix to a single channel QASYMM8 input image and return the result. 
+ * + * Convolution matrix layout: + * + * [ mat0, mat1, mat2 ]\n + * [ mat3, mat4, mat5 ]\n + * [ mat6, mat7, mat8 ]\n + * + * @param[in] src A pointer to source Image structure + * @param[in] mat0 Coefficient from the convolution matrix + * @param[in] mat1 Coefficient from the convolution matrix + * @param[in] mat2 Coefficient from the convolution matrix + * @param[in] mat3 Coefficient from the convolution matrix + * @param[in] mat4 Coefficient from the convolution matrix + * @param[in] mat5 Coefficient from the convolution matrix + * @param[in] mat6 Coefficient from the convolution matrix + * @param[in] mat7 Coefficient from the convolution matrix + * @param[in] mat8 Coefficient from the convolution matrix + * @param[in] input_offset Quantized offset of zero point of the input tensor data range + * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range + * @param[in] output_offset Quantized offset of zero point of the output tensor data range + * @param[in] output_multiplier Output scale multiplier + * @param[in] output_shift Output scale divisor exponent + * @param[in] bias (Optional) Bias value + * + * @return a uchar2 containing 2 convoluted values. + */ +inline uchar2 convolution3x3( + Image *src, + const uchar mat0, const uchar mat1, const uchar mat2, + const uchar mat3, const uchar mat4, const uchar mat5, + const uchar mat6, const uchar mat7, const uchar mat8, + const int input_offset, const int weight_offset, const int output_offset, + const int output_multiplier, const int output_shift +#if defined(HAS_BIAS) + , + const int bias +#endif //defined(HAS_BIAS) +) +{ + int2 pixels; + + pixels = convolution1x3(offset(src, 0, 0), mat0, mat1, mat2, input_offset, weight_offset); + pixels += convolution1x3(offset(src, 0, 1), mat3, mat4, mat5, input_offset, weight_offset); + pixels += convolution1x3(offset(src, 0, 2), mat6, mat7, mat8, input_offset, weight_offset); +#if defined(HAS_BIAS) + pixels += (int2)(bias); +#endif //defined(HAS_BIAS) + + pixels = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(pixels, output_multiplier, output_shift, 2); + pixels = pixels + output_offset; + pixels = clamp(pixels, 0, 255); + + return CONVERT(pixels, uchar2); +} + +/** This function computes the horizontal integral of the image. + * + * @param[in] src_ptr Pointer to the source image. Supported data types: QASYMM8 + * @param[in] src_stride_x Stride of the source image in X dimension (in bytes) + * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes) + * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image + * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_ptr Pointer to the destination tensor. 
Supported data types: QASYMM8 + * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] weights_ptr Pointer to the weights tensor. Supported data types: QASYMM8 + * @param[in] weights_stride_x Stride of the weights tensor in X dimension (in bytes) + * @param[in] weights_step_x weights_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] weights_stride_y Stride of the weights tensor in Y dimension (in bytes) + * @param[in] weights_step_y weights_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] weights_stride_z Stride of the weights tensor in Z dimension (in bytes) + * @param[in] weights_step_z weights_stride_z * number of elements along Y processed per workitem(in bytes) + * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor + * @param[in] biases_ptr (Optional) Pointer to the biases vector. Supported data types: QASYMM8 + * @param[in] biases_stride_x (Optional) Stride of the biases vector in X dimension (in bytes) + * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector + * @param[in] input_offset Quantized offset of zero point of the input tensor data range + * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range + * @param[in] output_offset Quantized offset of zero point of the output tensor data range + * @param[in] output_multiplier Output scale multiplier + * @param[in] output_shift Output scale divisor exponent + */ + +__kernel void depthwise_convolution_3x3_quantized( + TENSOR3D_DECLARATION(src), + TENSOR3D_DECLARATION(dst), + TENSOR3D_DECLARATION(weights), +#if defined(HAS_BIAS) + VECTOR_DECLARATION(biases), +#endif //defined(HAS_BIAS) + int input_offset, + int weight_offset, + int output_offset, + int output_multiplier, + int output_shift) +{ + Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); + Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); +#if defined(HAS_BIAS) + Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); +#endif //defined(HAS_BIAS) + + uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y; + uchar3 weights_values0 = vload3(0, weights.ptr + offset.s0); + uchar3 weights_values1 = vload3(0, weights.ptr + offset.s1); + uchar3 weights_values2 = vload3(0, weights.ptr + offset.s2); + +#if defined(HAS_BIAS) + int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2)))); +#endif //defined(HAS_BIAS) + + uchar2 pixels = convolution3x3(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2, + weights_values1.s0, weights_values1.s1, weights_values1.s2, + weights_values2.s0, weights_values2.s1, weights_values2.s2, + input_offset, 
weight_offset, output_offset, + output_multiplier, output_shift +#if defined(HAS_BIAS) + , + bias_value +#endif //defined(HAS_BIAS) + ); + + vstore2(pixels, 0, dst.ptr); +} + +#endif //defined(CONV_STRIDE_X) diff --git a/src/core/CL/cl_kernels/helpers_asymm.h b/src/core/CL/cl_kernels/helpers_asymm.h index 3c1d58bda1..b44d0f1fd2 100644 --- a/src/core/CL/cl_kernels/helpers_asymm.h +++ b/src/core/CL/cl_kernels/helpers_asymm.h @@ -44,6 +44,7 @@ return (x >> exponent) + select(zero, one, (x & mask) > threshold); \ } +ASYMM_ROUNDING_DIVIDE_BY_POW2_IMPL(2) ASYMM_ROUNDING_DIVIDE_BY_POW2_IMPL(8) ASYMM_ROUNDING_DIVIDE_BY_POW2_IMPL(16) @@ -80,6 +81,7 @@ ASYMM_ROUNDING_DIVIDE_BY_POW2_IMPL(16) return select(ab_x2_high32, INT_MAX, overflow); \ } +ASYMM_MULT_IMP(2) ASYMM_MULT_IMP(8) ASYMM_MULT_IMP(16) diff --git a/src/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp index 5f42450b9f..208d06d7cd 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp @@ -33,6 +33,7 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" using namespace arm_compute; @@ -48,14 +49,22 @@ BorderSize CLDepthwiseConvolution3x3Kernel::border_size() const void CLDepthwiseConvolution3x3Kernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::F32); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, weights); ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != 3 || weights->info()->dimension(1) != 3); if(biases != nullptr) { - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases); + if(is_data_type_quantized_asymmetric(weights->info()->data_type())) + { + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32); + } + else + { + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases); + } ARM_COMPUTE_ERROR_ON(biases->info()->dimension(0) != weights->info()->dimension(2)); ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() > 1); } @@ -80,13 +89,12 @@ void CLDepthwiseConvolution3x3Kernel::configure(const ICLTensor *input, const IC // Set build options ARM_COMPUTE_ERROR_ON(_conv_stride_x < 1 || _conv_stride_x > 3); - std::set options{ "-DCONV_STRIDE_X=" + support::cpp11::to_string(_conv_stride_x) }; - if(_biases != nullptr) - { - options.emplace("-DHAS_BIAS"); - } + CLBuildOptions build_opts; + build_opts.add_option("-DCONV_STRIDE_X=" + support::cpp11::to_string(_conv_stride_x)); + build_opts.add_option_if(_biases != nullptr, "-DHAS_BIAS"); - _kernel = static_cast(CLKernelLibrary::get().create_kernel("depthwise_convolution_3x3", options)); + std::string kernel_name = is_data_type_quantized_asymmetric(_input->info()->data_type()) ? 
"depthwise_convolution_3x3_quantized" : "depthwise_convolution_3x3"; + _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); // Configure kernel window const unsigned int num_elems_processed_per_iteration = 2; @@ -105,6 +113,23 @@ void CLDepthwiseConvolution3x3Kernel::configure(const ICLTensor *input, const IC output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape())); ICLKernel::configure(win); + + // Set static arguments + if(is_data_type_quantized_asymmetric(_input->info()->data_type())) + { + float multiplier = _input->info()->quantization_info().scale * _weights->info()->quantization_info().scale / _output->info()->quantization_info().scale; + int output_multiplier = 0; + int output_shift = 0; + quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift); + + unsigned int idx = 3 * num_arguments_per_3D_tensor() + ((_biases != nullptr) ? num_arguments_per_1D_tensor() : 0); + + _kernel.setArg(idx++, -_input->info()->quantization_info().offset); + _kernel.setArg(idx++, -_weights->info()->quantization_info().offset); + _kernel.setArg(idx++, _output->info()->quantization_info().offset); + _kernel.setArg(idx++, output_multiplier); + _kernel.setArg(idx++, output_shift); + } } void CLDepthwiseConvolution3x3Kernel::run(const Window &window, cl::CommandQueue &queue) diff --git a/src/runtime/CL/functions/CLDepthwiseConvolution.cpp b/src/runtime/CL/functions/CLDepthwiseConvolution.cpp index 156565950a..23a20a3011 100644 --- a/src/runtime/CL/functions/CLDepthwiseConvolution.cpp +++ b/src/runtime/CL/functions/CLDepthwiseConvolution.cpp @@ -37,9 +37,9 @@ CLDepthwiseConvolution3x3::CLDepthwiseConvolution3x3() void CLDepthwiseConvolution3x3::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, weights); _kernel.configure(input, weights, biases, output, conv_info); _border_handler.configure(input, _kernel.border_size(), BorderMode::CONSTANT, PixelValue(0)); diff --git a/tests/validation/CL/DepthwiseConvolution.cpp b/tests/validation/CL/DepthwiseConvolution.cpp index 5f1bde81dd..ccd9c36561 100644 --- a/tests/validation/CL/DepthwiseConvolution.cpp +++ b/tests/validation/CL/DepthwiseConvolution.cpp @@ -42,7 +42,8 @@ namespace validation { namespace { -constexpr RelativeTolerance tolerance_f32(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ +constexpr RelativeTolerance tolerance_f32(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ +constexpr RelativeTolerance tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QASYMM8 */ } // namespace TEST_SUITE(CL) @@ -52,11 +53,13 @@ template using CLDepthwiseConvolutionFixture = DepthwiseConvolutionValidationFixture; TEST_SUITE(Generic) -FIXTURE_DATA_TEST_CASE(RunSmall, CLDepthwiseConvolutionFixture, 
framework::DatasetMode::PRECOMMIT, datasets::SmallDepthwiseConvolutionDataset()) +FIXTURE_DATA_TEST_CASE(RunSmall, CLDepthwiseConvolutionFixture, framework::DatasetMode::ALL, combine(datasets::SmallDepthwiseConvolutionDataset(), framework::dataset::make("DataType", + DataType::F32))) { validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLDepthwiseConvolutionFixture, framework::DatasetMode::NIGHTLY, datasets::LargeDepthwiseConvolutionDataset()) +FIXTURE_DATA_TEST_CASE(RunLarge, CLDepthwiseConvolutionFixture, framework::DatasetMode::NIGHTLY, combine(datasets::LargeDepthwiseConvolutionDataset(), framework::dataset::make("DataType", + DataType::F32))) { validate(CLAccessor(_target), _reference, tolerance_f32); } @@ -65,16 +68,44 @@ TEST_SUITE_END() template using CLDepthwiseConvolutionFixture3x3 = DepthwiseConvolutionValidationFixture; +TEST_SUITE(Float) +TEST_SUITE(FP32) TEST_SUITE(W3x3) -FIXTURE_DATA_TEST_CASE(RunSmall, CLDepthwiseConvolutionFixture3x3, framework::DatasetMode::PRECOMMIT, datasets::SmallDepthwiseConvolutionDataset3x3()) +FIXTURE_DATA_TEST_CASE(RunSmall, CLDepthwiseConvolutionFixture3x3, framework::DatasetMode::ALL, combine(datasets::SmallDepthwiseConvolutionDataset3x3(), framework::dataset::make("DataType", + DataType::F32))) { validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLDepthwiseConvolutionFixture3x3, framework::DatasetMode::NIGHTLY, datasets::LargeDepthwiseConvolutionDataset3x3()) +FIXTURE_DATA_TEST_CASE(RunLarge, CLDepthwiseConvolutionFixture3x3, framework::DatasetMode::NIGHTLY, combine(datasets::LargeDepthwiseConvolutionDataset3x3(), framework::dataset::make("DataType", + DataType::F32))) { validate(CLAccessor(_target), _reference, tolerance_f32); } TEST_SUITE_END() +TEST_SUITE_END() +TEST_SUITE_END() + +template +using CLDepthwiseConvolutionQuantizedFixture3x3 = DepthwiseConvolutionValidationQuantizedFixture; + +TEST_SUITE(Quantized) +TEST_SUITE(QASYMM8) +TEST_SUITE(W3x3) +FIXTURE_DATA_TEST_CASE(RunSmall, CLDepthwiseConvolutionQuantizedFixture3x3, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallDepthwiseConvolutionDataset3x3(), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255, 127) }))) +{ + validate(CLAccessor(_target), _reference, tolerance_qasymm8); +} +FIXTURE_DATA_TEST_CASE(RunLarge, CLDepthwiseConvolutionQuantizedFixture3x3, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeDepthwiseConvolutionDataset3x3(), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255, 127) }))) +{ + validate(CLAccessor(_target), _reference, tolerance_qasymm8); +} +TEST_SUITE_END() +TEST_SUITE_END() +TEST_SUITE_END() TEST_SUITE_END() TEST_SUITE_END() diff --git a/tests/validation/CPP/ConvolutionLayer.cpp b/tests/validation/CPP/ConvolutionLayer.cpp index 95852b0f42..a767912879 100644 --- a/tests/validation/CPP/ConvolutionLayer.cpp +++ b/tests/validation/CPP/ConvolutionLayer.cpp @@ -55,8 +55,8 @@ void convolution3d(const SimpleTensor &in, const SimpleTensor &weights, co { const T *in_ptr = in.data() + i_offset; const T *w_ptr = weights.data() + w_offset; - const T *b_ptr = bias.data() + b_offset; - T *out_ptr = out.data() + o_offset; + const TB *b_ptr = bias.data() + b_offset; + T *out_ptr = out.data() + o_offset; const int half_width_weights = width_weights / 2; const int half_height_weights = 
height_weights / 2; diff --git a/tests/validation/CPP/DepthwiseConvolution.cpp b/tests/validation/CPP/DepthwiseConvolution.cpp index e29d014f77..ad0653846b 100644 --- a/tests/validation/CPP/DepthwiseConvolution.cpp +++ b/tests/validation/CPP/DepthwiseConvolution.cpp @@ -26,8 +26,13 @@ #include "ConvolutionLayer.h" #include "Utils.h" +#include "tests/validation/CPP/Utils.h" +#include "tests/validation/CPP/UtilsQuantizedAsymm.h" +#include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" + namespace arm_compute { namespace test @@ -44,8 +49,8 @@ namespace reference * - Padding, stride and output shape "match" * */ -template -SimpleTensor depthwise_convolution(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info) +template +SimpleTensor depthwise_convolution(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info) { // Create reference SimpleTensor dst{ dst_shape, src.data_type(), 1, src.fixed_point_position() }; @@ -97,8 +102,80 @@ SimpleTensor depthwise_convolution(const SimpleTensor &src, const SimpleTe } coords.set(0, x); coords.set(1, y); - dst[out_pos++] = saturate_cast(val + *static_cast(biases(Coordinates(z)))); + dst[out_pos++] = saturate_cast(val + *static_cast(biases(Coordinates(z)))); + } + } + } + } + + return dst; +} + +template <> +SimpleTensor depthwise_convolution(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &biases, const TensorShape &dst_shape, + const PadStrideInfo &conv_info) +{ + // Create reference + SimpleTensor dst{ dst_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; + + const int input_offset = -src.quantization_info().offset; + const float input_scale = src.quantization_info().scale; + const int weights_offset = -weights.quantization_info().offset; + const float weights_scale = weights.quantization_info().scale; + const int output_offset = dst.quantization_info().offset; + const float output_scale = dst.quantization_info().scale; + + int output_multiplier; + int output_shift; + const float multiplier = input_scale * weights_scale / output_scale; + arm_compute::quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift); + + // Compute reference + const int filter_width = weights.shape().x(); + const int filter_height = weights.shape().y(); + const int filter_plane = filter_width * filter_height; + const int input_width = src.shape().x(); + const int input_height = src.shape().y(); + const int input_depth = src.shape().z(); + + const int filter_half_size = filter_width / 2; + const int pad_x = std::min(filter_half_size, static_cast(conv_info.pad().first)); + const int pad_y = std::min(filter_half_size, static_cast(conv_info.pad().second)); + const int minimum_x = -pad_x + filter_half_size; + const int minimum_y = -pad_y + filter_half_size; + + int out_pos = 0; + for(int z = 0; z < input_depth; ++z) + { + int32_t bias_val = *static_cast(biases(Coordinates(z))); + for(int y = minimum_y; y < input_height + pad_y - filter_half_size; y += conv_info.stride().second) + { + for(int x = minimum_x; x < input_width + pad_x - filter_half_size; x += conv_info.stride().first) + { + Coordinates coords(x, y, z); + int filter_offset = filter_plane * z; + + uint32_t val = 0; + for(int j = y - filter_half_size; j <= (y + 
filter_half_size); ++j) + { + for(int i = x - filter_half_size; i <= (x + filter_half_size); ++i) + { + coords.set(0, i); + coords.set(1, j); + auto in_val = tensor_elem_at(src, coords, BorderMode::CONSTANT, 0); + uint8_t w_val = *(weights.data() + filter_offset); + val += (in_val + input_offset) * (w_val + weights_offset); + ++filter_offset; + } } + val += bias_val; + val = asymm_rounding_divide_by_pow2(asymm_int_mult(val, output_multiplier), output_shift); + val += output_offset; + val = std::max(val, 0); + val = std::min(val, 255); + + // Store the result + dst[out_pos++] = val; } } } diff --git a/tests/validation/CPP/DepthwiseConvolution.h b/tests/validation/CPP/DepthwiseConvolution.h index e8c55b16a8..df743a5b8e 100644 --- a/tests/validation/CPP/DepthwiseConvolution.h +++ b/tests/validation/CPP/DepthwiseConvolution.h @@ -35,8 +35,8 @@ namespace validation { namespace reference { -template -SimpleTensor depthwise_convolution(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info); +template +SimpleTensor depthwise_convolution(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/NEON/DepthwiseConvolution.cpp b/tests/validation/NEON/DepthwiseConvolution.cpp index 6e8aa46ed0..b6719b58e8 100644 --- a/tests/validation/NEON/DepthwiseConvolution.cpp +++ b/tests/validation/NEON/DepthwiseConvolution.cpp @@ -87,18 +87,22 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::da template using NEDepthwiseConvolutionFixture3x3 = DepthwiseConvolutionValidationFixture; +TEST_SUITE(Float) TEST_SUITE(F32) TEST_SUITE(W3x3) -FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthwiseConvolutionFixture3x3, framework::DatasetMode::PRECOMMIT, datasets::SmallDepthwiseConvolutionDataset3x3()) +FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthwiseConvolutionFixture3x3, framework::DatasetMode::ALL, combine(datasets::SmallDepthwiseConvolutionDataset3x3(), framework::dataset::make("DataType", + DataType::F32))) { validate(Accessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEDepthwiseConvolutionFixture3x3, framework::DatasetMode::NIGHTLY, datasets::LargeDepthwiseConvolutionDataset3x3()) +FIXTURE_DATA_TEST_CASE(RunLarge, NEDepthwiseConvolutionFixture3x3, framework::DatasetMode::NIGHTLY, combine(datasets::LargeDepthwiseConvolutionDataset3x3(), framework::dataset::make("DataType", + DataType::F32))) { validate(Accessor(_target), _reference, tolerance_f32); } TEST_SUITE_END() TEST_SUITE_END() +TEST_SUITE_END() TEST_SUITE_END() TEST_SUITE_END() diff --git a/tests/validation/fixtures/DepthwiseConvolutionFixture.h b/tests/validation/fixtures/DepthwiseConvolutionFixture.h index f49e76c70c..b1d31d657a 100644 --- a/tests/validation/fixtures/DepthwiseConvolutionFixture.h +++ b/tests/validation/fixtures/DepthwiseConvolutionFixture.h @@ -43,14 +43,22 @@ namespace test namespace validation { template -class DepthwiseConvolutionValidationFixture : public framework::Fixture +class DepthwiseConvolutionValidationGenericFixture : public framework::Fixture { +public: + using TBias = typename std::conditional::type, uint8_t>::value, int32_t, T>::type; + public: template - void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape out_shape, PadStrideInfo pad_stride_info) + void 
setup(TensorShape in_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info) { - _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info); - _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info); + _quantization_info = quantization_info; + _data_type = data_type; + + const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; + + _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info); + _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info); } protected: @@ -59,28 +67,46 @@ protected: { switch(tensor.data_type()) { + case DataType::QASYMM8: + { + std::uniform_int_distribution distribution(0, 10); + library->fill(tensor, distribution, i); + break; + } case DataType::F32: { std::uniform_real_distribution<> distribution(-1.0f, 1.0f); library->fill(tensor, distribution, i); break; } + case DataType::S32: + { + std::uniform_int_distribution distribution(-1000, 1000); + library->fill(tensor, distribution, i); + break; + } default: library->fill_tensor_uniform(tensor, i); } } - TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &output_shape, PadStrideInfo &pad_stride_info) + TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &output_shape, PadStrideInfo &pad_stride_info, + const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info) { // Create tensors - TensorType src = create_tensor(input_shape, DataType::F32); - TensorType weights = create_tensor(weights_shape, DataType::F32); - TensorType biases = create_tensor(biases_shape, DataType::F32); - TensorType dst = create_tensor(output_shape, DataType::F32); + TensorType src = create_tensor(input_shape, data_type, 1, 0, quantization_info); + TensorType weights = create_tensor(weights_shape, data_type, 1, 0, quantization_info); + TensorType biases = create_tensor(biases_shape, bias_data_type, 1, 0, quantization_info); + TensorType dst = create_tensor(output_shape, data_type, 1, 0, quantization_info); // Create Depthwise Convolution configure function - FunctionType depthwise_convolution; - depthwise_convolution.configure(&src, &weights, &biases, &dst, pad_stride_info); + FunctionType dwc; + dwc.configure(&src, &weights, &biases, &dst, pad_stride_info); + + ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(biases.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); // Allocate tensors src.allocator()->allocate(); @@ -99,16 +125,17 @@ protected: fill(AccessorType(biases), 2); // Compute function - depthwise_convolution.run(); + dwc.run(); return dst; } - SimpleTensor compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info) + SimpleTensor compute_reference(const TensorShape &in_shape, const 
TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info, + const DataType data_type, const DataType bias_data_type, QuantizationInfo quantization_info) { - SimpleTensor src(in_shape, DataType::F32); - SimpleTensor weights(weights_shape, DataType::F32); - SimpleTensor biases(biases_shape, DataType::F32); + SimpleTensor src{ in_shape, data_type, 1, 0, quantization_info }; + SimpleTensor weights{ weights_shape, data_type, 1, 0, quantization_info }; + SimpleTensor biases{ biases_shape, data_type, 1, 0, quantization_info }; fill(src, 0); fill(weights, 1); @@ -117,8 +144,34 @@ protected: return reference::depthwise_convolution(src, weights, biases, out_shape, pad_stride_info); } - TensorType _target{}; - SimpleTensor _reference{}; + TensorType _target{}; + SimpleTensor _reference{}; + DataType _data_type{}; + QuantizationInfo _quantization_info{}; +}; + +template +class DepthwiseConvolutionValidationFixture : public DepthwiseConvolutionValidationGenericFixture +{ +public: + template + void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type) + { + DepthwiseConvolutionValidationGenericFixture::setup(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, + data_type, QuantizationInfo()); + } +}; + +template +class DepthwiseConvolutionValidationQuantizedFixture : public DepthwiseConvolutionValidationGenericFixture +{ +public: + template + void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info) + { + DepthwiseConvolutionValidationGenericFixture::setup(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, + data_type, quantization_info); + } }; } // namespace validation } // namespace test diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h index 279a4897eb..1ec4d31304 100644 --- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h @@ -84,6 +84,12 @@ protected: library->fill(tensor, distribution, i); break; } + case DataType::S32: + { + std::uniform_int_distribution distribution(-1000, 1000); + library->fill(tensor, distribution, i); + break; + } default: library->fill_tensor_uniform(tensor, i); } -- cgit v1.2.1
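
The output_multiplier/output_shift pair that CLDepthwiseConvolution3x3Kernel::configure() and the new CPP reference feed into the requantization step is produced by quantization::calculate_quantized_multiplier_less_than_one() from input_scale * weights_scale / output_scale. A rough standalone sketch of that decomposition (quantize_multiplier_less_than_one is an illustrative name; the library helper additionally validates its inputs and may round differently in edge cases):

#include <cmath>
#include <cstdint>
#include <utility>

// Split a real multiplier in [0, 1) into a Q0.31 fixed-point multiplier and a right shift.
std::pair<int32_t, int> quantize_multiplier_less_than_one(double multiplier)
{
    if(multiplier == 0.0)
    {
        return std::make_pair(0, 0);
    }
    int          exponent = 0;
    const double q        = std::frexp(multiplier, &exponent); // multiplier = q * 2^exponent, q in [0.5, 1)
    int64_t      q_fixed  = static_cast<int64_t>(std::round(q * (1ll << 31)));
    if(q_fixed == (1ll << 31)) // q rounded up to exactly 1.0: halve it and absorb the factor of 2
    {
        q_fixed /= 2;
        ++exponent;
    }
    return std::make_pair(static_cast<int32_t>(q_fixed), -exponent); // right shift, >= 0 while multiplier < 1
}

With the QuantizationInfo(2.f / 255, 127) that the new CL tests use for input, weights and output, the multiplier is (2/255 * 2/255) / (2/255) = 2/255 ≈ 0.0078, which decomposes to a fixed-point multiplier of roughly 0.502 * 2^31 and a right shift of 6 — the values the kernel then consumes through ASYMM_MULT and ASYMM_ROUNDING_DIVIDE_BY_POW2.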