From 66831659fdef07c428993dccfa5d92416bae1ef9 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 1 Jul 2021 09:14:07 +0100 Subject: Add quantization helper functions for OpenCL Add `T_QUANTIZE8_PER_TENSOR` and `T_QUANTIZE8_PER_CHANNEL` that can be used to perform quantization on tile constructs. Signed-off-by: Georgios Pinitas Change-Id: Ie8e1efcb895c64715620acf2212b1de9a857ee0a Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5891 Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/core/CL/cl_kernels/tile_helpers.h | 109 ++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/src/core/CL/cl_kernels/tile_helpers.h b/src/core/CL/cl_kernels/tile_helpers.h index f2d2f26cf2..7910b4ce0e 100644 --- a/src/core/CL/cl_kernels/tile_helpers.h +++ b/src/core/CL/cl_kernels/tile_helpers.h @@ -647,5 +647,114 @@ }) \ }) +/** 8-bit quantization with fixed-point scale. Expands to T_QUANTIZE8_PER_TENSOR or T_QUANTIZE8_PER_CHANNEL by pasting QUANTIZATION_TYPE onto the T_QUANTIZE8_ prefix. + * + * @param[in] SRC_DATA_TYPE SRC data type + * @param[in] DST_DATA_TYPE DST data type + * @param[in] QUANTIZATION_TYPE Quantization type (PER_TENSOR or PER_CHANNEL) + * @param[in] M0 Number of src/dst rows + * @param[in] N0 Number of src/dst columns + * @param[in] DST_OFFSET Quantization offset used for both the per-tensor and per-channel quantization + * @param[in] DST_SHIFT Quantization shift for the per-tensor quantization (not read by the per-channel variant) + * @param[in] DST_MULTIPLIER Quantization multiplier for the per-tensor quantization (not read by the per-channel variant) + * @param[in] src Input tile + * @param[in] dst_multipliers Output multipliers tile for the per-channel quantization (row 0 is read, one entry per column; not read by the per-tensor variant) + * @param[in] dst_shifts Output shift tile for the per-channel quantization (row 0 is read, one entry per column; not read by the per-tensor variant) + * @param[out] dst Output tile + */ +#define T_QUANTIZE8(SRC_DATA_TYPE, DST_DATA_TYPE, QUANTIZATION_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) T_QUANTIZE8_STR(SRC_DATA_TYPE, DST_DATA_TYPE, QUANTIZATION_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, 
dst_shifts, dst) +#define T_QUANTIZE8_STR(SRC_DATA_TYPE, DST_DATA_TYPE, QUANTIZATION_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) T_QUANTIZE8_##QUANTIZATION_TYPE(SRC_DATA_TYPE, DST_DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) /* Second macro level: the arguments of T_QUANTIZE8 are macro-expanded before QUANTIZATION_TYPE is pasted onto the prefix here */ + +/** 8-bit per-tensor quantization with fixed-point scale: every element of the tile is scaled with the same DST_MULTIPLIER and DST_SHIFT + * + * @param[in] SRC_DATA_TYPE SRC data type + * @param[in] DST_DATA_TYPE DST data type + * @param[in] M0 Number of src/dst rows + * @param[in] N0 Number of src/dst columns + * @param[in] DST_OFFSET Quantization offset (added after the scaling, before the CONVERT_SAT to DST_DATA_TYPE) + * @param[in] DST_SHIFT Quantization shift for the per-tensor quantization (negative: the source is pre-multiplied by 2^(-DST_SHIFT); non-negative: applied as a rounding right shift after the multiply) + * @param[in] DST_MULTIPLIER Quantization multiplier for the per-tensor quantization + * @param[in] src Input tile + * @param[in] dst_multipliers (unused) + * @param[in] dst_shifts (unused) + * @param[out] dst Output tile + */ +#define T_QUANTIZE8_PER_TENSOR(SRC_DATA_TYPE, DST_DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) \ + ({ \ + LOOP_UNROLLING(int, _m0, 0, 1, M0, \ + { \ + LOOP_UNROLLING(int, _n0, 0, 1, N0, \ + { \ + SRC_DATA_TYPE _tmp = 0; \ + SRC_DATA_TYPE _src = src[_m0].s[_n0]; \ + _src *= select((SRC_DATA_TYPE)1, ((SRC_DATA_TYPE)1 << (SRC_DATA_TYPE)(-DST_SHIFT)), ((SRC_DATA_TYPE)DST_SHIFT < (SRC_DATA_TYPE)0)); /* multiply by 2^(-DST_SHIFT) when DST_SHIFT is negative, by 1 otherwise */ \ + SRC_DATA_TYPE overflow = _src == DST_MULTIPLIER && _src == INT_MIN; /* set only when both operands equal INT_MIN */ \ + long a_64 = (long)(_src); \ + long b_64 = (long)(DST_MULTIPLIER); \ + long ab_64 = a_64 * b_64; /* product computed in long */ \ + long mask1 = 1 << 30; \ + long mask2 = 1 - (1 << 30); \ + long is_positive_or_zero = ab_64 >= 0; \ + long nudge = select(mask2, mask1, is_positive_or_zero); /* OpenCL select(a, b, c) yields b when c is set: 2^30 for a non-negative product, 1 - 2^30 otherwise */ \ + SRC_DATA_TYPE ab_x2_high32 = CONVERT((ab_64 + nudge) / (long)(1ll << 31), SRC_DATA_TYPE); /* (product + nudge) / 2^31 */ \ + _tmp = select(ab_x2_high32, (SRC_DATA_TYPE)INT_MAX, overflow); /* INT_MAX replaces the result when overflow is set */ \ + if(DST_SHIFT >= 0) \ + { \ + long mask = ((((int)1) << DST_SHIFT) - (int)1); /* selects the low DST_SHIFT bits that the shift discards */ \ + long threshold = _tmp < (int)0 ? 
(mask >> 1) + (long)1 : (mask >> 1) + 0; /* one higher for a negative _tmp */ \ + _tmp = (_tmp & mask) > threshold ? (_tmp >> DST_SHIFT) + (int)1 : (_tmp >> DST_SHIFT); /* rounding right shift: add one when the discarded bits exceed the threshold (nearest, ties away from zero) */ \ + } \ + _tmp += DST_OFFSET; \ + dst[_m0].s[_n0] = CONVERT_SAT(_tmp, DST_DATA_TYPE); \ + }) \ + }) \ + }) + +/** 8-bit per-channel quantization with fixed-point scale: column _n0 is scaled with dst_multipliers[0].s[_n0] and dst_shifts[0].s[_n0] + * + * @param[in] SRC_DATA_TYPE SRC data type + * @param[in] DST_DATA_TYPE DST data type + * @param[in] M0 Number of src/dst rows + * @param[in] N0 Number of src/dst columns + * @param[in] DST_OFFSET Quantization offset (added after the scaling, before the CONVERT_SAT to DST_DATA_TYPE) + * @param[in] DST_SHIFT (unused) + * @param[in] DST_MULTIPLIER (unused) + * @param[in] src Input tile + * @param[in] dst_multipliers Output multipliers tile for the per-channel quantization (only row 0 is read, indexed by column) + * @param[in] dst_shifts Output shift tile for the per-channel quantization (only row 0 is read, indexed by column; negative: the source is pre-multiplied by 2^(-shift); non-negative: applied as a rounding right shift after the multiply) + * @param[out] dst Output tile + */ +#define T_QUANTIZE8_PER_CHANNEL(SRC_DATA_TYPE, DST_DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) \ + ({ \ + LOOP_UNROLLING(int, _m0, 0, 1, M0, \ + { \ + LOOP_UNROLLING(int, _n0, 0, 1, N0, \ + { \ + SRC_DATA_TYPE _tmp = 0; \ + SRC_DATA_TYPE _src = src[_m0].s[_n0]; \ + SRC_DATA_TYPE _dst_multiplier = dst_multipliers[0].s[_n0]; /* multiplier of column _n0, shared by all rows */ \ + SRC_DATA_TYPE _dst_shift = dst_shifts[0].s[_n0]; /* shift of column _n0, shared by all rows */ \ + _src *= select((SRC_DATA_TYPE)1, ((SRC_DATA_TYPE)1 << (SRC_DATA_TYPE)(-_dst_shift)), ((SRC_DATA_TYPE)_dst_shift < (SRC_DATA_TYPE)0)); /* multiply by 2^(-_dst_shift) when _dst_shift is negative, by 1 otherwise */ \ + SRC_DATA_TYPE overflow = _src == _dst_multiplier && _src == INT_MIN; /* set only when both operands equal INT_MIN */ \ + long a_64 = (long)(_src); \ + long b_64 = (long)(_dst_multiplier); \ + long ab_64 = a_64 * b_64; /* product computed in long */ \ + long mask1 = 1 << 30; \ + long mask2 = 1 - (1 << 30); \ + long is_positive_or_zero = ab_64 >= 0; \ + long nudge = select(mask2, mask1, is_positive_or_zero); /* OpenCL select(a, b, c) yields b when c is set: 2^30 for a non-negative product, 1 - 2^30 otherwise */ \ + SRC_DATA_TYPE ab_x2_high32 = CONVERT((ab_64 + nudge) / (long)(1ll << 31), SRC_DATA_TYPE); /* (product + nudge) / 2^31 */ \ + _tmp = select(ab_x2_high32, (SRC_DATA_TYPE)INT_MAX, overflow); /* INT_MAX replaces the result when overflow is set */ \ + if(_dst_shift >= 0) \ + { \ + long mask = ((((int)1) << _dst_shift) - (int)1); /* selects the low _dst_shift bits that the shift discards */ \ + long threshold = 
_tmp < (int)0 ? (mask >> 1) + (long)1 : (mask >> 1) + 0; /* one higher for a negative _tmp */ \ + _tmp = (_tmp & mask) > threshold ? (_tmp >> _dst_shift) + (int)1 : (_tmp >> _dst_shift); /* rounding right shift: add one when the discarded bits exceed the threshold (nearest, ties away from zero) */ \ + } \ + _tmp += DST_OFFSET; \ + dst[_m0].s[_n0] = CONVERT_SAT(_tmp, DST_DATA_TYPE); \ + }) \ + }) \ + }) // clang-format on // *INDENT-ON* \ No newline at end of file -- cgit v1.2.1