aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2021-07-01 09:14:07 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2021-07-01 10:22:40 +0000
commit66831659fdef07c428993dccfa5d92416bae1ef9 (patch)
treee7fa1515298e4f0885377485bbafe75388918adb
parent2ef59b9053a402600ad75baebdb4909553044698 (diff)
downloadComputeLibrary-66831659fdef07c428993dccfa5d92416bae1ef9.tar.gz
Add quantization helper functions for OpenCL
Add `T_QUANTIZE8_PER_TENSOR` and `T_QUANTIZE8_PER_CHANNEL` that can be used to perform quantization on tile constructs. Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: Ie8e1efcb895c64715620acf2212b1de9a857ee0a Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5891 Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/core/CL/cl_kernels/tile_helpers.h109
1 files changed, 109 insertions, 0 deletions
diff --git a/src/core/CL/cl_kernels/tile_helpers.h b/src/core/CL/cl_kernels/tile_helpers.h
index f2d2f26cf2..7910b4ce0e 100644
--- a/src/core/CL/cl_kernels/tile_helpers.h
+++ b/src/core/CL/cl_kernels/tile_helpers.h
@@ -647,5 +647,114 @@
}) \
})
+/** 8-bit quantization with fixed-point scale
+ *
+ * @param[in] SRC_DATA_TYPE SRC data type
+ * @param[in] DST_DATA_TYPE DST data type
+ * @param[in] QUANTIZATION_TYPE Quantization type (PER_TENSOR or PER_CHANNEL)
+ * @param[in] M0 Number of src/dst rows
+ * @param[in] N0 Number of src/dst columns
+ * @param[in] DST_OFFSET Quantization offset used for both the per-tensor and per-channel quantization
+ * @param[in] DST_SHIFT Quantization shift for the per-tensor quantization
+ * @param[in] DST_MULTIPLIER Quantization multiplier for the per-tensor quantization
+ * @param[in] src Input tile
+ * @param[in] dst_multipliers Output multipliers tile for the per-channel quantization
+ * @param[in] dst_shifts Output shift tile for the per-channel quantization
+ * @param[out] dst Output tile
+ */
+#define T_QUANTIZE8(SRC_DATA_TYPE, DST_DATA_TYPE, QUANTIZATION_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) T_QUANTIZE8_STR(SRC_DATA_TYPE, DST_DATA_TYPE, QUANTIZATION_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst)
+#define T_QUANTIZE8_STR(SRC_DATA_TYPE, DST_DATA_TYPE, QUANTIZATION_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) T_QUANTIZE8_##QUANTIZATION_TYPE(SRC_DATA_TYPE, DST_DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst)
+
+/** 8-bit per-tensor quantization with fixed-point scale
+ *
+ * @param[in] SRC_DATA_TYPE SRC data type
+ * @param[in] DST_DATA_TYPE DST data type
+ * @param[in] M0 Number of src/dst rows
+ * @param[in] N0 Number of src/dst columns
+ * @param[in] DST_OFFSET Quantization offset
+ * @param[in] DST_SHIFT Quantization shift for the per-tensor quantization
+ * @param[in] DST_MULTIPLIER Quantization multiplier for the per-tensor quantization
+ * @param[in] src Input tile
+ * @param[in] dst_multipliers (unused)
+ * @param[in] dst_shifts (unused)
+ * @param[out] dst Output tile
+ */
+#define T_QUANTIZE8_PER_TENSOR(SRC_DATA_TYPE, DST_DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) \
+ ({ \
+ LOOP_UNROLLING(int, _m0, 0, 1, M0, \
+ { \
+ LOOP_UNROLLING(int, _n0, 0, 1, N0, \
+ { \
+ SRC_DATA_TYPE _tmp = 0; \
+ SRC_DATA_TYPE _src = src[_m0].s[_n0]; \
+ _src *= select((SRC_DATA_TYPE)1, ((SRC_DATA_TYPE)1 << (SRC_DATA_TYPE)(-DST_SHIFT)), ((SRC_DATA_TYPE)DST_SHIFT < (SRC_DATA_TYPE)0)); \
+ SRC_DATA_TYPE overflow = _src == DST_MULTIPLIER && _src == INT_MIN; \
+ long a_64 = (long)(_src); \
+ long b_64 = (long)(DST_MULTIPLIER); \
+ long ab_64 = a_64 * b_64; \
+ long mask1 = 1 << 30; \
+ long mask2 = 1 - (1 << 30); \
+ long is_positive_or_zero = ab_64 >= 0; \
+ long nudge = select(mask2, mask1, is_positive_or_zero); \
+ SRC_DATA_TYPE ab_x2_high32 = CONVERT((ab_64 + nudge) / (long)(1ll << 31), SRC_DATA_TYPE); \
+ _tmp = select(ab_x2_high32, (SRC_DATA_TYPE)INT_MAX, overflow); \
+ if(DST_SHIFT >= 0) \
+ { \
+ long mask = ((((int)1) << DST_SHIFT) - (int)1); \
+ long threshold = _tmp < (int)0 ? (mask >> 1) + (long)1 : (mask >> 1) + 0; \
+ _tmp = (_tmp & mask) > threshold ? (_tmp >> DST_SHIFT) + (int)1 : (_tmp >> DST_SHIFT); \
+ } \
+ _tmp += DST_OFFSET; \
+ dst[_m0].s[_n0] = CONVERT_SAT(_tmp, DST_DATA_TYPE); \
+ }) \
+ }) \
+ })
+
+/** 8-bit per-channel quantization with fixed-point scale
+ *
+ * @param[in] SRC_DATA_TYPE SRC data type
+ * @param[in] DST_DATA_TYPE DST data type
+ * @param[in] M0 Number of src/dst rows
+ * @param[in] N0 Number of src/dst columns
+ * @param[in] DST_OFFSET Quantization offset
+ * @param[in] DST_SHIFT (unused)
+ * @param[in] DST_MULTIPLIER (unused)
+ * @param[in] src Input tile
+ * @param[in] dst_multipliers Output multipliers tile for the per-channel quantization
+ * @param[in] dst_shifts Output shift tile for the per-channel quantization
+ * @param[out] dst Output tile
+ */
+#define T_QUANTIZE8_PER_CHANNEL(SRC_DATA_TYPE, DST_DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, src, dst_multipliers, dst_shifts, dst) \
+ ({ \
+ LOOP_UNROLLING(int, _m0, 0, 1, M0, \
+ { \
+ LOOP_UNROLLING(int, _n0, 0, 1, N0, \
+ { \
+ SRC_DATA_TYPE _tmp = 0; \
+ SRC_DATA_TYPE _src = src[_m0].s[_n0]; \
+ SRC_DATA_TYPE _dst_multiplier = dst_multipliers[0].s[_n0]; \
+ SRC_DATA_TYPE _dst_shift = dst_shifts[0].s[_n0]; \
+ _src *= select((SRC_DATA_TYPE)1, ((SRC_DATA_TYPE)1 << (SRC_DATA_TYPE)(-_dst_shift)), ((SRC_DATA_TYPE)_dst_shift < (SRC_DATA_TYPE)0)); \
+ SRC_DATA_TYPE overflow = _src == _dst_multiplier && _src == INT_MIN; \
+ long a_64 = (long)(_src); \
+ long b_64 = (long)(_dst_multiplier); \
+ long ab_64 = a_64 * b_64; \
+ long mask1 = 1 << 30; \
+ long mask2 = 1 - (1 << 30); \
+ long is_positive_or_zero = ab_64 >= 0; \
+ long nudge = select(mask2, mask1, is_positive_or_zero); \
+ SRC_DATA_TYPE ab_x2_high32 = CONVERT((ab_64 + nudge) / (long)(1ll << 31), SRC_DATA_TYPE); \
+ _tmp = select(ab_x2_high32, (SRC_DATA_TYPE)INT_MAX, overflow); \
+ if(_dst_shift >= 0) \
+ { \
+ long mask = ((((int)1) << _dst_shift) - (int)1); \
+ long threshold = _tmp < (int)0 ? (mask >> 1) + (long)1 : (mask >> 1) + 0; \
+ _tmp = (_tmp & mask) > threshold ? (_tmp >> _dst_shift) + (int)1 : (_tmp >> _dst_shift); \
+ } \
+ _tmp += DST_OFFSET; \
+ dst[_m0].s[_n0] = CONVERT_SAT(_tmp, DST_DATA_TYPE); \
+ }) \
+ }) \
+ })
// clang-format on
// *INDENT-ON* \ No newline at end of file