From 3f632f3f16e29ebeb7065b30008060fd4bfd09f1 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Thu, 22 Aug 2019 16:52:00 +0100 Subject: COMPMID-2418: CLDequantizationLayer support for QASYMM8_PER_CHANNEL Add support for QASYMM8_PER_CHANNEL in CLDequantiazationLayer. Added tests for NHWC and also updated NEON code to work with NHWC data layout. Cleaned up the reference implementation. Change-Id: Ic1d51f16f7f625503fffdbbb66f6487aa588f08c Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/1828 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Georgios Pinitas --- .../core/CL/kernels/CLDequantizationLayerKernel.h | 4 +- arm_compute/core/NEON/NEAsymm.h | 22 ++++ arm_compute/runtime/CL/CLTensorAllocator.h | 4 +- .../runtime/CL/functions/CLDequantizationLayer.h | 4 +- src/core/CL/CLHelpers.cpp | 8 ++ src/core/CL/CLKernelLibrary.cpp | 2 + src/core/CL/cl_kernels/dequantization_layer.cl | 134 ++++++++++++++++++++- .../CL/kernels/CLDequantizationLayerKernel.cpp | 39 ++++-- .../NEON/kernels/NEDequantizationLayerKernel.cpp | 64 +++++++++- src/runtime/CL/CLTensorAllocator.cpp | 12 +- tests/validation/CL/DequantizationLayer.cpp | 39 ++++-- tests/validation/NEON/DequantizationLayer.cpp | 39 ++++-- .../fixtures/DequantizationLayerFixture.h | 16 ++- tests/validation/reference/DequantizationLayer.cpp | 18 +-- 14 files changed, 346 insertions(+), 59 deletions(-) diff --git a/arm_compute/core/CL/kernels/CLDequantizationLayerKernel.h b/arm_compute/core/CL/kernels/CLDequantizationLayerKernel.h index 0ee5a13638..830d7518ce 100644 --- a/arm_compute/core/CL/kernels/CLDequantizationLayerKernel.h +++ b/arm_compute/core/CL/kernels/CLDequantizationLayerKernel.h @@ -48,13 +48,13 @@ public: ~CLDequantizationLayerKernel() = default; /** Set the input, output, min and max. * - * @param[in] input Source tensor. Data types supported: QASYMM8/QSYMM8/QSYMM16. + * @param[in] input Source tensor. Data types supported: QASYMM8/QASYMM8_PER_CHANNEL/QSYMM8/QSYMM16. * @param[out] output Destination tensor. Data types supported: F16/F32. */ void configure(const ICLTensor *input, ICLTensor *output); /** Static function to check if given info will lead to a valid configuration of @ref CLDequantizationLayerKernel * - * @param[in] input Input tensor info. Data types supported: QASYMM8/QSYMM8/QSYMM16. + * @param[in] input Input tensor info. Data types supported: QASYMM8/QASYMM8_PER_CHANNEL/QSYMM8/QSYMM16. * @param[in] output Output tensor info. Data types supported: F16/F32. * * @return a status diff --git a/arm_compute/core/NEON/NEAsymm.h b/arm_compute/core/NEON/NEAsymm.h index 981c7b075c..f2d20d373a 100644 --- a/arm_compute/core/NEON/NEAsymm.h +++ b/arm_compute/core/NEON/NEAsymm.h @@ -226,6 +226,28 @@ inline float32x4x4_t vdequantize(const uint8x16_t &qv, float scale, int32_t offs return vdequantized_input; } +/** Dequantize following an asymmetric quantization scheme a neon vector holding 16 quantized values. + * + * @param[in] qv Input values to be dequantized. + * @param[in] vscale Vector containing quantization scaling factors. + * @param[in] voffset Vector containing quantization offset. + * + * @return Dequantized values in a neon vector + */ +inline float32x4x4_t vdequantize(const uint8x16_t &qv, const float32x4x4_t vscale, const int32x4x4_t voffset) +{ + const float32x4x4_t vdequantized_input = + { + { + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(qv))))), voffset.val[0])), vscale.val[0]), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(qv))))), voffset.val[1])), vscale.val[1]), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(qv))))), voffset.val[2])), vscale.val[2]), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(qv))))), voffset.val[3])), vscale.val[3]), + } + }; + return vdequantized_input; +} + /** Dequantize following a symmetric quantization scheme a neon vector holding 16 quantized values. * * @param[in] qv Input values to be dequantized. diff --git a/arm_compute/runtime/CL/CLTensorAllocator.h b/arm_compute/runtime/CL/CLTensorAllocator.h index 982cc51274..f7800d39f8 100644 --- a/arm_compute/runtime/CL/CLTensorAllocator.h +++ b/arm_compute/runtime/CL/CLTensorAllocator.h @@ -146,8 +146,8 @@ private: CLMemory _memory; /**< OpenCL memory */ uint8_t *_mapping; /**< Pointer to the CPU mapping of the OpenCL buffer. */ CLTensor *_owner; /**< Owner of the allocator */ - CLFloatArray _scale; - CLInt32Array _offset; + CLFloatArray _scale; /**< Scales array in case of quantized per channel data type */ + CLInt32Array _offset; /**< Offsets array in case of quantized per channel data type */ }; } // namespace arm_compute #endif /* __ARM_COMPUTE_CLTENSORALLOCATOR_H__ */ diff --git a/arm_compute/runtime/CL/functions/CLDequantizationLayer.h b/arm_compute/runtime/CL/functions/CLDequantizationLayer.h index ade589d79e..c519311fb1 100644 --- a/arm_compute/runtime/CL/functions/CLDequantizationLayer.h +++ b/arm_compute/runtime/CL/functions/CLDequantizationLayer.h @@ -40,13 +40,13 @@ public: /** Set the input and output tensors. * * @param[in] input Source tensor with at least 3 dimensions. The dimensions over the third will be interpreted as batches. - * Data types supported: QASYMM8/QSYMM8/QSYMM16. + * Data types supported: QASYMM8/QASYMM8_PER_CHANNEL/QSYMM8/QSYMM16. * @param[out] output Destination tensor with the same dimensions of input. Data type supported: F16/F32. */ void configure(const ICLTensor *input, ICLTensor *output); /** Static function to check if given info will lead to a valid configuration of @ref CLDequantizationLayer * - * @param[in] input Input tensor info. Data types supported: QASYMM8/QSYMM8/QSYMM16. + * @param[in] input Input tensor info. Data types supported: QASYMM8/QASYMM8_PER_CHANNEL/QSYMM8/QSYMM16. * @param[in] output Output tensor info. Data type supported: F16/F32. * * @return a status diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp index e80349e486..bb3cf7fda2 100644 --- a/src/core/CL/CLHelpers.cpp +++ b/src/core/CL/CLHelpers.cpp @@ -38,9 +38,11 @@ std::string get_cl_type_from_data_type(const DataType &dt) { case DataType::U8: case DataType::QASYMM8: + case DataType::QASYMM8_PER_CHANNEL: return "uchar"; case DataType::S8: case DataType::QSYMM8: + case DataType::QSYMM8_PER_CHANNEL: return "char"; case DataType::U16: return "ushort"; @@ -71,9 +73,11 @@ std::string get_cl_select_type_from_data_type(const DataType &dt) { case DataType::U8: case DataType::QASYMM8: + case DataType::QASYMM8_PER_CHANNEL: return "uchar"; case DataType::S8: case DataType::QSYMM8: + case DataType::QSYMM8_PER_CHANNEL: return "char"; case DataType::U16: return "ushort"; @@ -104,6 +108,8 @@ std::string get_data_size_from_data_type(const DataType &dt) case DataType::S8: case DataType::QSYMM8: case DataType::QASYMM8: + case DataType::QSYMM8_PER_CHANNEL: + case DataType::QASYMM8_PER_CHANNEL: return "8"; case DataType::U16: case DataType::S16: @@ -246,6 +252,8 @@ size_t preferred_vector_width(const cl::Device &device, const DataType dt) case DataType::S8: case DataType::QASYMM8: case DataType::QSYMM8: + case DataType::QSYMM8_PER_CHANNEL: + case DataType::QASYMM8_PER_CHANNEL: return device.getInfo(); case DataType::U16: case DataType::S16: diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp index 4b3b37c3da..d1500f00b5 100644 --- a/src/core/CL/CLKernelLibrary.cpp +++ b/src/core/CL/CLKernelLibrary.cpp @@ -236,6 +236,8 @@ const std::map CLKernelLibrary::_kernel_program_map = { "depthwise_im2col", "depthwise_convolution.cl" }, { "depthwise_vector_to_tensor", "depthwise_convolution.cl" }, { "dequantization_layer", "dequantization_layer.cl" }, + { "dequantization_layer_per_channel_nhwc", "dequantization_layer.cl" }, + { "dequantization_layer_per_channel_nchw", "dequantization_layer.cl" }, { "derivative", "derivative.cl" }, { "dilate", "dilate.cl" }, { "direct_convolution1x1", "direct_convolution1x1.cl" }, diff --git a/src/core/CL/cl_kernels/dequantization_layer.cl b/src/core/CL/cl_kernels/dequantization_layer.cl index 7d87dc6a2d..5826847a5e 100644 --- a/src/core/CL/cl_kernels/dequantization_layer.cl +++ b/src/core/CL/cl_kernels/dequantization_layer.cl @@ -87,5 +87,137 @@ __kernel void dequantization_layer( *((__global DATA_TYPE_DST *)(output.ptr)) = (DATA_TYPE_DST)((float)((int)(*((__global DATA_TYPE_SRC *)(input.ptr))) - (int)(OFFSET)) * (float)(SCALE)); #endif // defined(LAST_ACCESSED_X) } - #endif // defined(VEC_SIZE) && defined(DATA_TYPE_SRC) && defined(DATA_TYPE_DST) && defined(SCALE) && defined(OFFSET) + +#if defined(VEC_SIZE) && defined(DATA_TYPE_SRC) && defined(DATA_TYPE_DST) +/** This performs per channel dequantization of 8-bit unsigned integers to floating point. (NCHW) + * + * @note Source datatype should be given as a preprocessor argument using -DDATA_TYPE_SRC=type. e.g. -DDATA_TYPE_SRC=char + * @note Destination datatype should be given as a preprocessor argument using -DDATA_TYPE_DST=type. e.g. -DDATA_TYPE_DST=float + * @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. e.g. -DVEC_SIZE=16 + * + * @param[in] input_ptr Pointer to the source tensor. Supported data types: QASYMM8_PER_CHANNEL + * @param[in] input_stride_x Stride of the source tensor in X dimension (in bytes) + * @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] input_stride_y Stride of the source tensor in Y dimension (in bytes) + * @param[in] input_step_y input_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] input_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] input_step_z input_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] input_offset_first_element_in_bytes The offset of the first element in the source tensor + * @param[out] output_ptr Pointer to the destination tensor. Supported data types: F16/F32 + * @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] output_step_x output_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] output_step_y output_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] output_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] output_step_z output_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] output_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] scale Pointer to buffer with the per channel quantized scales + * @param[in] offset Pointer to buffer with the per channel quantized offsets + */ +__kernel void dequantization_layer_per_channel_nchw( + TENSOR3D_DECLARATION(input), + TENSOR3D_DECLARATION(output), + __global float *scale, + __global int *offset) +{ + // Get pixels pointer + Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input); + Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output); + +#if defined(LAST_ACCESSED_X) + // Check if access on width gets out of bounds + // If it does shift access vector to access elements within bounds + const int xi = (int)(get_global_id(0) * VEC_SIZE); + input.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * input_stride_x; + output.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * output_stride_x; + + // Load data + VEC_DATA_TYPE(int, VEC_SIZE) + val = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE_SRC *)input.ptr), VEC_DATA_TYPE(int, VEC_SIZE)); + + // Create scale and offset vectors + const VEC_DATA_TYPE(float, VEC_SIZE) + vscale = scale[get_global_id(2)]; + + const VEC_DATA_TYPE(int, VEC_SIZE) + voffset = offset[get_global_id(2)]; + + // Dequantize + VEC_DATA_TYPE(float, VEC_SIZE) + res = vscale * CONVERT((val - voffset), VEC_DATA_TYPE(float, VEC_SIZE)); + + // Store result + VSTORE(VEC_SIZE) + (CONVERT(res, VEC_DATA_TYPE(DATA_TYPE_DST, VEC_SIZE)), 0, (__global DATA_TYPE_DST *)output.ptr); +#else // !defined(LAST_ACCESSED_X) + *((__global DATA_TYPE_DST *)(output.ptr)) = (DATA_TYPE_DST)((float)((int)(*((__global DATA_TYPE_SRC *)(input.ptr))) - offset[get_global_id(2)]) * scale[get_global_id(2)]); +#endif // defined(LAST_ACCESSED_X) +} +/** This performs per channel dequantization of 8-bit unsigned integers to floating point. (NHWC) + * + * @note Source datatype should be given as a preprocessor argument using -DDATA_TYPE_SRC=type. e.g. -DDATA_TYPE_SRC=char + * @note Destination datatype should be given as a preprocessor argument using -DDATA_TYPE_DST=type. e.g. -DDATA_TYPE_DST=float + * @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. e.g. -DVEC_SIZE=16 + * + * @param[in] input_ptr Pointer to the source tensor. Supported data types: QASYMM8_PER_CHANNEL + * @param[in] input_stride_x Stride of the source tensor in X dimension (in bytes) + * @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] input_stride_y Stride of the source tensor in Y dimension (in bytes) + * @param[in] input_step_y input_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] input_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] input_step_z input_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] input_offset_first_element_in_bytes The offset of the first element in the source tensor + * @param[out] output_ptr Pointer to the destination tensor. Supported data types: F16/F32 + * @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] output_step_x output_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] output_step_y output_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] output_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] output_step_z output_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] output_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] scale Pointer to buffer with the per channel quantized scales + * @param[in] offset Pointer to buffer with the per channel quantized offsets + */ +__kernel void dequantization_layer_per_channel_nhwc( + TENSOR3D_DECLARATION(input), + TENSOR3D_DECLARATION(output), + __global float *scale, + __global int *offset) +{ + // Get pixels pointer + Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input); + Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output); + +#if defined(LAST_ACCESSED_X) + // Check if access on width gets out of bounds + // If it does shift access vector to access elements within bounds + const int xi = (int)(get_global_id(0) * VEC_SIZE); + input.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * input_stride_x; + output.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * output_stride_x; + scale -= max(xi - (int)LAST_ACCESSED_X, 0); + offset -= max(xi - (int)LAST_ACCESSED_X, 0); + + // Load data + VEC_DATA_TYPE(int, VEC_SIZE) + val = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE_SRC *)input.ptr), VEC_DATA_TYPE(int, VEC_SIZE)); + + // Create scale and offset vectors + const VEC_DATA_TYPE(float, VEC_SIZE) + vscale = VLOAD(VEC_SIZE)(0, &scale[xi]); + + const VEC_DATA_TYPE(int, VEC_SIZE) + voffset = VLOAD(VEC_SIZE)(0, &offset[xi]); + + // Dequantize + VEC_DATA_TYPE(float, VEC_SIZE) + res = vscale * CONVERT((val - voffset), VEC_DATA_TYPE(float, VEC_SIZE)); + + // Store result + VSTORE(VEC_SIZE) + (CONVERT(res, VEC_DATA_TYPE(DATA_TYPE_DST, VEC_SIZE)), 0, (__global DATA_TYPE_DST *)output.ptr); +#else // !defined(LAST_ACCESSED_X) + *((__global DATA_TYPE_DST *)(output.ptr)) = (DATA_TYPE_DST)((float)((int)(*((__global DATA_TYPE_SRC *)(input.ptr))) - offset[get_global_id(0)]) * scale[get_global_id(0)]); +#endif // defined(LAST_ACCESSED_X) +} +#endif // defined(VEC_SIZE) && defined(DATA_TYPE_SRC) && defined(DATA_TYPE_DST) diff --git a/src/core/CL/kernels/CLDequantizationLayerKernel.cpp b/src/core/CL/kernels/CLDequantizationLayerKernel.cpp index 10a2878be7..3ec0b87636 100644 --- a/src/core/CL/kernels/CLDequantizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLDequantizationLayerKernel.cpp @@ -40,7 +40,7 @@ namespace Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QSYMM8, DataType::QSYMM16); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_PER_CHANNEL, DataType::QSYMM8, DataType::QSYMM16); if(output->tensor_shape().total_size() > 0) { @@ -95,20 +95,31 @@ void CLDequantizationLayerKernel::configure(const ICLTensor *input, ICLTensor *o } ICLKernel::configure_internal(win); - const UniformQuantizationInfo qinfo = input->info()->quantization_info().uniform(); - const int qoffset = is_data_type_quantized_asymmetric(input->info()->data_type()) ? qinfo.offset : 0; + const bool is_quantized_per_channel = is_data_type_quantized_per_channel(input->info()->data_type()); + std::string kernel_name = "dequantization_layer"; // Create kernel CLBuildOptions build_opts; - build_opts.add_option("-DSCALE=" + float_to_string_with_full_precision(qinfo.scale)); - build_opts.add_option("-DOFFSET=" + support::cpp11::to_string(qoffset)); + if(!is_quantized_per_channel) + { + const UniformQuantizationInfo qinfo = input->info()->quantization_info().uniform(); + const int qoffset = is_data_type_quantized_asymmetric(input->info()->data_type()) ? qinfo.offset : 0; + build_opts.add_option("-DSCALE=" + float_to_string_with_full_precision(qinfo.scale)); + build_opts.add_option("-DOFFSET=" + support::cpp11::to_string(qoffset)); + } + else + { + kernel_name += "_per_channel"; + kernel_name += input->info()->data_layout() == DataLayout::NCHW ? "_nchw" : "_nhwc"; + } + build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); build_opts.add_option("-DDATA_TYPE_SRC=" + get_cl_type_from_data_type(input->info()->data_type())); build_opts.add_option("-DDATA_TYPE_DST=" + get_cl_type_from_data_type(output->info()->data_type())); build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max(output_width_x - vec_size_x, 0))); // Create kernel name - _kernel = static_cast(CLKernelLibrary::get().create_kernel("dequantization_layer", build_opts.options())); + _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); } Status CLDequantizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output) @@ -123,8 +134,18 @@ void CLDequantizationLayerKernel::run(const Window &window, cl::CommandQueue &qu ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); - Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), 3); - Window slice = window_collapsed.first_slice_window_3D(); + const bool is_quantized_per_channel = is_data_type_quantized_per_channel(_input->info()->data_type()); + + // Collapse windo + Window new_window = is_quantized_per_channel ? window.collapse_if_possible(ICLKernel::window(), 4) : window.collapse_if_possible(ICLKernel::window(), 3); + Window slice = new_window.first_slice_window_3D(); + + if(is_quantized_per_channel) + { + unsigned int idx = num_arguments_per_3D_tensor() * 2; //Skip the input and output parameters + _kernel.setArg(idx++, _input->quantization().scale->cl_buffer()); + _kernel.setArg(idx++, _input->quantization().offset->cl_buffer()); + } do { @@ -133,6 +154,6 @@ void CLDequantizationLayerKernel::run(const Window &window, cl::CommandQueue &qu add_3D_tensor_argument(idx, _output, slice); enqueue(queue, *this, slice, lws_hint()); } - while(window_collapsed.slide_window_slice_3D(slice)); + while(new_window.slide_window_slice_3D(slice)); } } // namespace arm_compute \ No newline at end of file diff --git a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp index d880c80d82..49de3ec8b3 100644 --- a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp @@ -160,7 +160,7 @@ void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Win } template -void run_dequantization_qasymm8_per_channel(const ITensor *input, ITensor *output, const Window &window) +void run_dequantization_qasymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window) { const std::vector scale = input->info()->quantization_info().scale(); const std::vector offset = input->info()->quantization_info().offset(); @@ -201,6 +201,66 @@ void run_dequantization_qasymm8_per_channel(const ITensor *input, ITensor *outpu in, out); } +template +void run_dequantization_qasymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window) +{ + const std::vector scale = input->info()->quantization_info().scale(); + const std::vector offset = input->info()->quantization_info().offset(); + + const int window_step_x = 16; + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + + // Reset first dimension to handle tail calculations manually + Window win(window); + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Create iterators + Iterator in(input, win); + Iterator out(output, win); + + execute_window_loop(win, [&](const Coordinates & id) + { + const auto in_ptr = reinterpret_cast(in.ptr()); + const auto out_ptr = reinterpret_cast(out.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float32x4x4_t vscale = + { + { + scale[x + 0], scale[x + 1], scale[x + 2], scale[x + 3], + scale[x + 4], scale[x + 5], scale[x + 6], scale[x + 7], + scale[x + 8], scale[x + 9], scale[x + 10], scale[x + 11], + scale[x + 12], scale[x + 13], scale[x + 14], scale[x + 15] + } + }; + const int32x4x4_t voffset = + { + { + offset[x + 0], offset[x + 1], offset[x + 2], offset[x + 3], + offset[x + 4], offset[x + 5], offset[x + 6], offset[x + 7], + offset[x + 8], offset[x + 9], offset[x + 10], offset[x + 11], + offset[x + 12], offset[x + 13], offset[x + 14], offset[x + 15] + } + }; + const auto vin = wrapper::vloadq(in_ptr + x); + const auto vdeq = vdequantize(vin, vscale, voffset); + + store_result(reinterpret_cast(out_ptr + x), vdeq); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + uint8_t val = *(in_ptr + x); + *(out_ptr + x) = static_cast(dequantize(val, scale[x], offset[x])); + } + }, + in, out); +} + template void run_dequantization_qsymm8(const ITensor *input, ITensor *output, const Window &window) { @@ -294,7 +354,7 @@ void run_dequantization_core(const ITensor *input, ITensor *output, const Window run_dequantization_qasymm8(input, output, window); break; case DataType::QASYMM8_PER_CHANNEL: - run_dequantization_qasymm8_per_channel(input, output, window); + input->info()->data_layout() == DataLayout::NHWC ? run_dequantization_qasymm8_per_channel_nhwc(input, output, window) : run_dequantization_qasymm8_per_channel_nchw(input, output, window); break; case DataType::QSYMM8: run_dequantization_qsymm8(input, output, window); diff --git a/src/runtime/CL/CLTensorAllocator.cpp b/src/runtime/CL/CLTensorAllocator.cpp index 028a764fc2..51caf69297 100644 --- a/src/runtime/CL/CLTensorAllocator.cpp +++ b/src/runtime/CL/CLTensorAllocator.cpp @@ -79,8 +79,6 @@ void clear_quantization_arrays(CLFloatArray &scale, CLInt32Array &offset) * @param[in, out] offset Quantization offset array * @param[in] qinfo Quantization info * @param[in] pad_size Pad size to use in case array needs to be padded for computation purposes - * - * @return A pair (scale, offset) containing the respective allocated and filled arrays */ void populate_quantization_info(CLFloatArray &scale, CLInt32Array &offset, const QuantizationInfo &qinfo, size_t pad_size) { @@ -93,6 +91,16 @@ void populate_quantization_info(CLFloatArray &scale, CLInt32Array &offset, const scale = CLFloatArray(num_elements + pad_size); scale.resize(num_elements); CLScheduler::get().queue().enqueueWriteBuffer(scale.cl_buffer(), CL_TRUE, 0, num_elements * element_size, qinfo.scale().data()); + + if(!qinfo.offset().empty()) + { + // Create offset array + const std::vector &qoffset = qinfo.offset(); + const size_t offset_element_size = sizeof(std::remove_reference::type::value_type); + offset = CLInt32Array(num_elements + pad_size); + offset.resize(num_elements); + CLScheduler::get().queue().enqueueWriteBuffer(offset.cl_buffer(), CL_TRUE, 0, num_elements * offset_element_size, qinfo.offset().data()); + } } } // namespace diff --git a/tests/validation/CL/DequantizationLayer.cpp b/tests/validation/CL/DequantizationLayer.cpp index 2ef8c60998..acc0022d3e 100644 --- a/tests/validation/CL/DequantizationLayer.cpp +++ b/tests/validation/CL/DequantizationLayer.cpp @@ -41,6 +41,33 @@ namespace test { namespace validation { +namespace +{ +const auto dataset_quant_f32 = combine(combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto dataset_quant_f16 = combine(combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto dataset_quant_per_channel_f32 = combine(combine(combine(datasets::SmallShapes(), datasets::QuantizedPerChannelTypes()), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); +const auto dataset_quant_per_channel_f16 = combine(combine(combine(datasets::SmallShapes(), datasets::QuantizedPerChannelTypes()), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); +const auto dataset_quant_nightly_f32 = combine(combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto dataset_quant_nightly_f16 = combine(combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto dataset_quant_per_channel_nightly_f32 = combine(combine(combine(datasets::LargeShapes(), datasets::QuantizedPerChannelTypes()), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); +const auto dataset_quant_per_channel_nightly_f16 = combine(combine(combine(datasets::LargeShapes(), datasets::QuantizedPerChannelTypes()), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); +} // namespace TEST_SUITE(CL) TEST_SUITE(DequantizationLayer) @@ -97,14 +124,12 @@ template using CLDequantizationLayerFixture = DequantizationValidationFixture; TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLDequantizationLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()), - framework::dataset::make("DataType", DataType::F16))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLDequantizationLayerFixture, framework::DatasetMode::PRECOMMIT, concat(dataset_quant_f16, dataset_quant_per_channel_f16)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()), - framework::dataset::make("DataType", DataType::F16))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture, framework::DatasetMode::NIGHTLY, concat(dataset_quant_nightly_f16, dataset_quant_per_channel_nightly_f16)) { // Validate output validate(CLAccessor(_target), _reference); @@ -112,14 +137,12 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture, framework:: TEST_SUITE_END() // FP16 TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLDequantizationLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLDequantizationLayerFixture, framework::DatasetMode::PRECOMMIT, concat(dataset_quant_f32, dataset_quant_per_channel_f32)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture, framework::DatasetMode::NIGHTLY, concat(dataset_quant_nightly_f32, dataset_quant_per_channel_nightly_f32)) { // Validate output validate(CLAccessor(_target), _reference); diff --git a/tests/validation/NEON/DequantizationLayer.cpp b/tests/validation/NEON/DequantizationLayer.cpp index 005ed6900c..0dce76a933 100644 --- a/tests/validation/NEON/DequantizationLayer.cpp +++ b/tests/validation/NEON/DequantizationLayer.cpp @@ -48,6 +48,31 @@ const auto data_types = framework::dataset::make("DataType", { DataType::F16, Da #else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ const auto data_types = framework::dataset::make("DataType", { DataType::F32 }); #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +const auto dataset_quant_f32 = combine(combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto dataset_quant_f16 = combine(combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto dataset_quant_per_channel_f32 = combine(combine(combine(datasets::SmallShapes(), datasets::QuantizedPerChannelTypes()), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); +const auto dataset_quant_per_channel_f16 = combine(combine(combine(datasets::SmallShapes(), datasets::QuantizedPerChannelTypes()), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); +const auto dataset_quant_nightly_f32 = combine(combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto dataset_quant_nightly_f16 = combine(combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto dataset_quant_per_channel_nightly_f32 = combine(combine(combine(datasets::LargeShapes(), datasets::QuantizedPerChannelTypes()), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); +const auto dataset_quant_per_channel_nightly_f16 = combine(combine(combine(datasets::LargeShapes(), datasets::QuantizedPerChannelTypes()), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })); } // namespace TEST_SUITE(NEON) @@ -107,14 +132,12 @@ using NEDequantizationLayerFixture = DequantizationValidationFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()), - framework::dataset::make("DataType", DataType::F16))) +FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture, framework::DatasetMode::PRECOMMIT, concat(dataset_quant_f16, dataset_quant_per_channel_f16)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()), - framework::dataset::make("DataType", DataType::F16))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture, framework::DatasetMode::NIGHTLY, concat(dataset_quant_nightly_f16, dataset_quant_per_channel_nightly_f16)) { // Validate output validate(Accessor(_target), _reference); @@ -123,16 +146,12 @@ TEST_SUITE_END() // FP16 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), concat(datasets::QuantizedTypes(), - datasets::QuantizedPerChannelTypes())), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture, framework::DatasetMode::PRECOMMIT, concat(dataset_quant_f32, dataset_quant_per_channel_f32)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), concat(datasets::QuantizedTypes(), - datasets::QuantizedPerChannelTypes())), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture, framework::DatasetMode::NIGHTLY, concat(dataset_quant_nightly_f32, dataset_quant_per_channel_nightly_f32)) { // Validate output validate(Accessor(_target), _reference); diff --git a/tests/validation/fixtures/DequantizationLayerFixture.h b/tests/validation/fixtures/DequantizationLayerFixture.h index 4842ee1c59..c7a818fcc7 100644 --- a/tests/validation/fixtures/DequantizationLayerFixture.h +++ b/tests/validation/fixtures/DequantizationLayerFixture.h @@ -32,6 +32,7 @@ #include "tests/IAccessor.h" #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" +#include "tests/validation/Helpers.h" #include "tests/validation/reference/DequantizationLayer.h" #include @@ -47,10 +48,10 @@ class DequantizationValidationFixture : public framework::Fixture { public: template - void setup(TensorShape shape, DataType src_data_type, DataType dst_datatype) + void setup(TensorShape shape, DataType src_data_type, DataType dst_datatype, DataLayout data_layout) { _quantization_info = generate_quantization_info(src_data_type, shape.z()); - _target = compute_target(shape, src_data_type, dst_datatype); + _target = compute_target(shape, src_data_type, dst_datatype, data_layout); _reference = compute_reference(shape, src_data_type); } @@ -61,11 +62,16 @@ protected: library->fill_tensor_uniform(tensor, 0); } - TensorType compute_target(const TensorShape &shape, DataType src_data_type, DataType dst_datatype) + TensorType compute_target(TensorShape shape, DataType src_data_type, DataType dst_datatype, DataLayout data_layout) { + if(data_layout == DataLayout::NHWC) + { + permute(shape, PermutationVector(2U, 0U, 1U)); + } + // Create tensors - TensorType src = create_tensor(shape, src_data_type, 1, _quantization_info); - TensorType dst = create_tensor(shape, dst_datatype); + TensorType src = create_tensor(shape, src_data_type, 1, _quantization_info, data_layout); + TensorType dst = create_tensor(shape, dst_datatype, 1, QuantizationInfo(), data_layout); // Create and configure function FunctionType dequantization_layer; diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp index 74686bdaaf..69a49a3d6d 100644 --- a/tests/validation/reference/DequantizationLayer.cpp +++ b/tests/validation/reference/DequantizationLayer.cpp @@ -50,9 +50,9 @@ TOut dequantize(int16_t val, const UniformQuantizationInfo qinfo) { return static_cast(dequantize_qsymm16(val, qinfo)); } - +} // namespace template -SimpleTensor dequantization_layer_nchw(const SimpleTensor &src) +SimpleTensor dequantization_layer(const SimpleTensor &src) { const DataType src_data_type = src.data_type(); const DataType dst_data_type = std::is_same::value ? DataType::F32 : DataType::F16; @@ -97,20 +97,6 @@ SimpleTensor dequantization_layer_nchw(const SimpleTensor &src) return dst; } -} // namespace -template -SimpleTensor dequantization_layer(const SimpleTensor &src) -{ - if(src.data_layout() == DataLayout::NHWC && src.data_type() == DataType::QSYMM8_PER_CHANNEL) - { - SimpleTensor src_nchw = reference::permute(src, PermutationVector(1U, 2U, 0U)); - return reference::permute(dequantization_layer_nchw(src_nchw), PermutationVector(2U, 0U, 1U)); - } - else - { - return dequantization_layer_nchw(src); - } -} template SimpleTensor dequantization_layer(const SimpleTensor &src); template SimpleTensor dequantization_layer(const SimpleTensor &src); -- cgit v1.2.1