From 532ce2c84dd24cb0c5064a3d2e5c7b4094df0e01 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Thu, 14 Sep 2023 09:13:49 +0100 Subject: Separate the output quantization calculation logic from matmul This patch generalizes the suggested output quantization calculation to any operation that employs a dot product between two vectors, i.e. c = sum_k(a_k * b_k) + d It also considers and suggests min/max boundaries for random S32 bias generation, depending on the accumulation result. MatMulKernelFixture is modified to use this interface. Signed-off-by: Gunes Bayir Change-Id: Ibb528261bb0310015967e11bd7ccd9ed9cff8479 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10312 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: SiCong Li Benchmark: Arm Jenkins --- tests/validation/Helpers.cpp | 135 +++++++++++++++++++----- tests/validation/Helpers.h | 48 ++++++++- tests/validation/fixtures/MatMulKernelFixture.h | 47 +++++---- 3 files changed, 184 insertions(+), 46 deletions(-) diff --git a/tests/validation/Helpers.cpp b/tests/validation/Helpers.cpp index 5e02cc843c..2f273e7042 100644 --- a/tests/validation/Helpers.cpp +++ b/tests/validation/Helpers.cpp @@ -26,6 +26,8 @@ #include #include +#include +#include namespace arm_compute { @@ -350,9 +352,47 @@ void add_padding_x(std::initializer_list tensors, const DataLayout &d } } -QuantizationInfo calculate_mat_mul_dst_q_info(const QuantizationInfo &a_q_info, const QuantizationInfo &b_q_info, int m, int n, int k, DataType data_type) +QuantizationHint suggest_matmul_dst_q_info_and_bias(const QuantizationInfo &lhs_q_info, + const QuantizationInfo &rhs_q_info, int32_t m, int32_t n, int32_t k, DataType data_type, + float bias_fraction) { ARM_COMPUTE_UNUSED(m, n); + + /** Quantization Setup of matrix multiplication + * + * We have a matrix multiplication of the form C = A * B + D + * where A is (m x k), B is (k x n) and C is therefore (m x n). + * The bias, D, is (1 x n). + * + * If we have some distributional statistics of A, B and D, i.e. mean and variance, + * we can estimate the mean and variance of a single value in the C matrix and pick + * good scale and offset values for the output and have non-saturated tests. + * + * Each element in the output matrix can be calculated as follows: + * C_ij = sum_k(A_ik * B_kj) + D_j + * + * Note: All possible A_ik, B_kj, D_j random variables are assumed mutually independent. + * Note: In quantized operators, bias is an integer. But, its quantization scale is + * assumed to be equal to lhs_scale * rhs_scale, and its offset equal to 0. + * Note: Since the bias is an integer that should be given as input, we need to pick reasonable + * values when adding it on top of the summation. This is where "bias_fraction" comes + * into play. Based on the fraction given, we also return a suggested bias range (min/max) + * that does not saturate the output. + * + * Because all random variables are mutually independent, any C_ij has the same statistics, + * which is why we return a single destination quantization info object, and why we can + * resort to the more general calculation explained in suggest_mac_dst_q_info_and_bias().
+ * + * From a probabilistic perspective, the above calculation reduces to + * c = sum_k (a_k * b_k) + d + */ + + return suggest_mac_dst_q_info_and_bias(lhs_q_info, rhs_q_info, k, data_type, bias_fraction); +} + +QuantizationHint suggest_mac_dst_q_info_and_bias( + const QuantizationInfo &a_q_info, const QuantizationInfo &b_q_info, int32_t K, DataType data_type, float bias_fraction) +{ QuantizationInfo c_q_info; ARM_COMPUTE_ASSERT(data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED); @@ -360,21 +400,13 @@ QuantizationInfo calculate_mat_mul_dst_q_info(const QuantizationInfo &a_q_info, const int32_t t_max = static_cast(data_type == DataType::QASYMM8 ? std::numeric_limits::max() : std::numeric_limits::max()); const int32_t t_min = static_cast(data_type == DataType::QASYMM8 ? std::numeric_limits::min() : std::numeric_limits::min()); - /** Quantization Setup of matrix multiplication + /** Quantization Setup of multiply-accumulate * - * We have a matrix multiplication of the form C = A * B - * where A is (M X K), B is (K x N) and C is therefore (M x N). + * Expression (in float): + * C = sum_k ( A_k * B_k ) + D * - * If we have some distributions statistics of A and B, i.e. mean and variance, - * we can estimate the mean and variance of a single value in C matrix and - * pick good scale and offset values for the output and have non-saturated tests. - * - * Each element in the output matrix can be calculated as follows: - * C_ij = sum_k(A_ik * B_kj) - * - * All values are float above. - * - * Note: All possible A_ik, B_kj random variables are assumed mutually independent. + * Lemma: An affine transformation (i.e. aX + b) applied to a discrete uniform random variable + * creates another discrete uniform random variable. * * Terminology: * E[X]: Mean of the random variable X (sometimes referred as mu_x) @@ -382,26 +414,58 @@ QuantizationInfo calculate_mat_mul_dst_q_info(const QuantizationInfo &a_q_info, * std(X): sqrt(var(X)), standard deviation of X * * 1) Calculate the mean: - * E[C_ij] = sum_k( E[A_ik] * E[B_kj] ) = K * mean_a * mean_b + * E[C] = sum_k( E[A_k] * E[B_k] ) + E[D] = K * mean_a * mean_b + mean_d * * Since elements of A and B are uniformly distributed random variables, we have * mean_a = (max_a + min_a) / 2, mean_b = (max_b + min_b ) / 2 * max_a and min_a can be calculated with the scale_a/b and offset_a/b * by replacing data type minimum and maximums in the equations * + * We don't know mean_d because we have to choose it based on bias_fraction. If we call + * the summation M_int, then similarly to the above we have: + * + * E[C_int] = sum_k( E[A_k_int] * E[B_k_int] ) + E[D_int] = K * mean_a_int * mean_b_int + mean_d_int + * \___________________________/ + * E[M_int] + * + * We choose a bias mean proportional to the integer summation. This proportion is "bias_fraction". + * So, we have D_int = f * M_int (f: fraction), and + * E[D_int] = mean_d_int = f * E[M_int] + * + * This also means, for the floating point value of D, the following: + * E[D] = mean_d = E[D_int] * a_scale * b_scale + * * 2) Calculate the variance: - * var(C_ij) = sum_k( var(A_ik * B_kj) ) - * = sum_k ( E[A_ik^2 * B_kj^2] - E[A_ik]^2E[B_kj^2] ) + * var(C) = sum_k( var(A_k * B_k) ) + var(D) + * = sum_k ( E[A_k^2 * B_k^2] - E[A_k]^2 E[B_k]^2 ) * = ...
- * = K * (var_a * var_b + var_a * mean^2_b + var_b * mean^2_a) + * = K * (var_a * var_b + var_a * mean^2_b + var_b * mean^2_a) + var_d * * Similarly, due to uniform random variable properties, we have * var_a = (max_a - min_a)^2 / 12 * var_b = (max_b - min_b)^2 / 12 * + * Again, we don't know var_d as we don't know the bias. As set out in the previous section, we have + * var(D_int) = var(f * M_int) = f^2 * var(M_int) + * + * Using the same expression, we can find var(M_int): + * var(C_int) = sum_k( var(A_k_int * B_k_int) ) + var(D_int) + * = sum_k ( E[A_k_int^2 * B_k_int^2] - E[A_k_int]^2 E[B_k_int]^2 ) + * = ... + * = K * (var_a_int * var_b_int + var_a_int * mean^2_b_int + var_b_int * mean^2_a_int) + var_d_int + * \_______________________________________________________________________________/ + * var(M_int) + * + * Now that we know the mean and variance of D_int, we can return a suitable bias range as + * [mean_d_int +/- 2 * std_d_int] + * + * This also means, for the floating point value of D, the following: + * var(D) = var_d = var(D_int) * a_scale^2 * b_scale^2 * - * 3) Now, we have an idea of what would an average C_ij will like and how much deviation - * is present around it. The exact distribution of C is not easy to come up with dependent on K. + * E[D] and var(D) calculated in steps (1) and (2) can be substituted into the E[C] and var(C) calculations. + * + * 3) Now, we have an idea of what an average C will look like and how much deviation + * is present around it. The exact distribution of C is difficult to derive and depends on K. * But, as K increases, due to Central Limit Theorem, it'll look more like a bell shaped figure, * approaching normal distribution. * @@ -424,6 +488,12 @@ QuantizationInfo calculate_mat_mul_dst_q_info(const QuantizationInfo &a_q_info, const int32_t b_offset = b_q_info.uniform().offset; const float b_scale = b_q_info.uniform().scale; + // Integer value statistics.
Valid for both Lhs/A and Rhs/B + const float mean_a_int = (t_max + t_min) / 2.f; + constexpr float var_a_int = (256 * 256 - 1) / 12.f; // Discrete uniform RV variance + const float mean_b_int = mean_a_int; // A_int and B_int have the same stats + constexpr float var_b_int = var_a_int; + // Lhs/A stats const float max_a = (t_max - a_offset) * a_scale; const float min_a = (t_min - a_offset) * a_scale; @@ -436,9 +506,25 @@ QuantizationInfo calculate_mat_mul_dst_q_info(const QuantizationInfo &a_q_info, const float mean_b = (max_b + min_b) / 2; const float var_b = (max_b - min_b) * (max_b - min_b) / 12; - // Output stats - const float mean_out = k * mean_a * mean_b; - const float var_out = k * (var_a * var_b + var_a * mean_b * mean_b + var_b * mean_a * mean_a); + // Integer multiplication output/M stats + const float mean_m_int = K * mean_a_int * mean_b_int; + const float var_m_int = K * (var_a_int * var_b_int + var_a_int * mean_b_int * mean_b_int + var_b_int * mean_a_int * mean_a_int); + const float std_m_int = sqrt(var_m_int); + + // Bias/D statistics, both integer and float + const float mean_d_int = bias_fraction * mean_m_int; + const float std_d_int = bias_fraction * std_m_int; + const float mean_d = a_scale * b_scale * mean_d_int; + const float std_d = a_scale * b_scale * std_d_int; + const float var_d = std_d * std_d; + + // Also calculate the suggested bias range + const int32_t min_bias = mean_d_int - 2 * std_d_int; + const int32_t max_bias = mean_d_int + 2 * std_d_int; + + // Output/C stats + const float mean_out = K * mean_a * mean_b + mean_d; + const float var_out = K * (var_a * var_b + var_a * mean_b * mean_b + var_b * mean_a * mean_a) + var_d; const float std_out = sqrt(var_out); // Output quantization setup @@ -446,7 +532,8 @@ QuantizationInfo calculate_mat_mul_dst_q_info(const QuantizationInfo &a_q_info, const int32_t offset_out = static_cast(t_min - (mean_out - 2.f * std_out) / scale_out); c_q_info = QuantizationInfo(scale_out, offset_out); - return c_q_info; + + return { c_q_info, min_bias, max_bias }; } template void get_tile(const SimpleTensor &in, SimpleTensor &roi, const Coordinates &coord); diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h index c9d990d3a2..7d53c1de37 100644 --- a/tests/validation/Helpers.h +++ b/tests/validation/Helpers.h @@ -31,7 +31,8 @@ #include "tests/Globals.h" #include "tests/SimpleTensor.h" -#include +#include +#include #include #include #include @@ -52,6 +53,19 @@ struct is_floating_point : public std::true_type { }; +/** Helper struct to store the hints for + * - destination quantization info + * - minimum bias value + * - maximum bias value + * in quantized test construction. + */ +struct QuantizationHint +{ + QuantizationInfo q_info; + int32_t bias_min; + int32_t bias_max; +}; + /** Helper function to get the testing range for each activation layer. * * @param[in] activation Activation function to test. @@ -226,10 +240,36 @@ std::pair get_symm_quantized_per_channel_bounds(const QuantizationInfo */ void add_padding_x(std::initializer_list tensors, const DataLayout &data_layout = DataLayout::NHWC, bool only_right_pad = false); -/** For MatMulLowp, given the Lhs/Rhs matrix quantization informations and the matrix multiplication dimensions, - * calculate a suitable output quantization for obtaining non-saturated outputs with high probability.
+/** For a matrix multiplication, given the Lhs/Rhs matrix quantization information and the matrix multiplication dimensions, + * calculate a suitable output quantization and suggested bias range for obtaining non-saturated outputs with high probability. + * + * @param[in] lhs_q_info Lhs matrix quantization info + * @param[in] rhs_q_info Rhs matrix quantization info + * @param[in] m Number of rows of Lhs matrix + * @param[in] n Number of columns of Rhs matrix + * @param[in] k Number of rows/columns of Rhs/Lhs matrix + * @param[in] data_type data type, only QASYMM8, QASYMM8_SIGNED are supported + * @param[in] bias_fraction the fraction of bias amplitude compared to the integer accumulation. 0 if there is no bias. + * + * @return QuantizationHint object containing the suggested output quantization info and min/max bias range + */ +QuantizationHint suggest_matmul_dst_q_info_and_bias(const QuantizationInfo &lhs_q_info, + const QuantizationInfo &rhs_q_info, int32_t m, int32_t n, int32_t k, DataType data_type, + float bias_fraction); + +/** For a multiply-accumulate (mac), given the Lhs/Rhs vector quantization information and the dot product dimensions, + * calculate a suitable output quantization and suggested bias range for obtaining non-saturated outputs with high probability. + * + * @param[in] lhs_q_info Lhs matrix quantization info + * @param[in] rhs_q_info Rhs matrix quantization info + * @param[in] k number of accumulations taking place in the sum, i.e. c = sum_k(a_k * b_k) + * @param[in] data_type data type, only QASYMM8, QASYMM8_SIGNED are supported + * @param[in] bias_fraction the fraction of bias amplitude compared to the integer accumulation. + * + * @return QuantizationHint object containing the suggested output quantization info and min/max bias range */ -QuantizationInfo calculate_mat_mul_dst_q_info(const QuantizationInfo &lhs_q_info, const QuantizationInfo &rhs_q_info, int m, int n, int k, DataType data_type); +QuantizationHint suggest_mac_dst_q_info_and_bias(const QuantizationInfo &lhs_q_info, + const QuantizationInfo &rhs_q_info, int32_t k, DataType data_type, float bias_fraction); } // namespace validation } // namespace test } // namespace arm_compute diff --git a/tests/validation/fixtures/MatMulKernelFixture.h b/tests/validation/fixtures/MatMulKernelFixture.h index 91ac77d5af..50d194c43a 100644 --- a/tests/validation/fixtures/MatMulKernelFixture.h +++ b/tests/validation/fixtures/MatMulKernelFixture.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ACL_TESTS_VALIDATION_FIXTURES_MATMULKERNELFIXTURE -#define ACL_TESTS_VALIDATION_FIXTURES_MATMULKERNELFIXTURE +#ifndef ACL_TESTS_VALIDATION_FIXTURES_MATMULKERNELFIXTURE_H +#define ACL_TESTS_VALIDATION_FIXTURES_MATMULKERNELFIXTURE_H #include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/Utils.h" @@ -54,6 +54,12 @@ public: void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type, bool enable_bias) { + // This hash is used by the random generators. There may be hash collisions, but + // this is intentional, as it is a very easy way to make the current + // random generation process differ across test configurations + // that previously used the same set of values.
+ _hash = M0 + N0 + K0 + shape_a[0] + shape_a[1] + shape_b[0] + shape_b[1] + enable_bias + export_rhs_to_cl_image; + // Flag to create a bias _enable_bias = enable_bias; @@ -67,7 +73,7 @@ public: const int32_t t_max = static_cast(std::numeric_limits::max()); const int32_t t_min = static_cast(std::numeric_limits::min()); - std::mt19937 generator(library->seed()); + std::mt19937 generator(library->seed() + _hash); std::uniform_real_distribution distribution_float(-5.0f, 3.0f); std::uniform_int_distribution distribution_t(t_min, t_max); @@ -84,7 +90,12 @@ public: const int n = shape_b.x(); const int k = shape_a.x(); - dst_q_info = calculate_mat_mul_dst_q_info(lhs_q_info, rhs_q_info, m, n, k, data_type); + const float bias_fraction = enable_bias ? 0.5f : 0.f; + + QuantizationHint q_hint = suggest_matmul_dst_q_info_and_bias(lhs_q_info, rhs_q_info, m, n, k, data_type, bias_fraction); + dst_q_info = q_hint.q_info; + _min_bias = q_hint.bias_min; + _max_bias = q_hint.bias_max; } if(pretranspose_a) @@ -142,12 +153,9 @@ protected: } template - void fill_bias_s32(U &&tensor, int i, const UniformQuantizationInfo &q_info) + void fill_bias_s32(U &&tensor, int i, int32_t min, int32_t max) { - // For quantized cases, fill the S32 bias according to the following to avoid saturation of test cases. - // The following code limits size of bias values to within expected range of output quantization. - const unsigned int bound = std::abs(q_info.scale * 256); // 256 is size of 8 bit datatype - std::uniform_int_distribution distribution(-(bound / 10), (bound / 10)); + std::uniform_int_distribution distribution(min, max); library->fill(tensor, distribution, i); } @@ -192,8 +200,8 @@ protected: ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); // Fill tensors - fill(CLAccessor(a), 0); - fill(CLAccessor(b), 1); + fill(CLAccessor(a), _hash + 1); + fill(CLAccessor(b), _hash + 2); // Compute matMul kernel ITensorPack tensors_pack({ { ACL_SRC_0, &a }, @@ -207,11 +215,11 @@ protected: bias.allocator()->allocate(); if(is_quantized) { - fill_bias_s32(CLAccessor(bias), 2, dst_q_info.uniform()); + fill_bias_s32(CLAccessor(bias), _hash + 3, _min_bias, _max_bias); } else { - fill(CLAccessor(bias), 2); + fill(CLAccessor(bias), _hash + 3); } tensors_pack.add_tensor(ACL_SRC_2, &bias); } @@ -236,8 +244,8 @@ protected: SimpleTensor c{ output_shape_collapsed, data_type, 1, dst_q_info }; // Fill reference - fill(a, 0); - fill(b, 1); + fill(a, _hash + 1); + fill(b, _hash + 2); /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M), therefore, A must be pre-transposed before passing it to the fixture. 
And, we transpose A again in the fixture to make it (B x M x K) @@ -288,7 +296,7 @@ protected: // of bias tensor from shape [dst.dimension(0)] to [dst.tensor_shape()] in target kernel if(_enable_bias) { - fill(c, 2); + fill(c, _hash + 3); const int n = c.shape().x(); const int other_dims = c.shape().collapsed_from(1)[1]; for(int i = 1; i < other_dims; ++i) // For all data, copy first n elements into remaining batches @@ -323,7 +331,7 @@ protected: if(_enable_bias) { // Identical to float implementation, fill and copy values of bias first dimension - fill_bias_s32(bias, 2, cq); + fill_bias_s32(bias, _hash + 3, _min_bias, _max_bias); const int n = bias.shape().x(); const int other_dims = bias.shape().collapsed_from(1)[1]; const unsigned int dt_size = sizeof(int32_t); @@ -348,6 +356,9 @@ protected: bool _enable_bias{ false }; bool _device_supports_export_to_cl_image{ true }; bool _device_supports_mmul{ true }; + int32_t _min_bias{ 0 }; + int32_t _max_bias{ 0 }; + int32_t _hash{ 0 }; }; template @@ -374,4 +385,4 @@ public: } // namespace validation } // namespace test } // namespace arm_compute -#endif /* ACL_TESTS_VALIDATION_FIXTURES_MATMULKERNELFIXTURE */ +#endif // ACL_TESTS_VALIDATION_FIXTURES_MATMULKERNELFIXTURE_H -- cgit v1.2.1
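For readers who want to check the statistics above numerically, here is a minimal standalone sketch (not part of the patch) that follows the same derivation for QASYMM8 inputs: it computes the mean and variance of the integer accumulation, derives a bias range from a bias fraction, and maps mean +/- 2*sigma of the float result onto the 8-bit range. The names QHint and suggest_dst_q are illustrative placeholders, not ComputeLibrary API, and the example assumes both operands use the full 0..255 range.

// Standalone sketch of the statistics-based destination quantization hint (illustrative only).
#include <cmath>
#include <cstdint>
#include <iostream>

struct QHint
{
    float   scale;
    int32_t offset;
    int32_t bias_min;
    int32_t bias_max;
};

QHint suggest_dst_q(float a_scale, int32_t a_offset, float b_scale, int32_t b_offset, int32_t K, float bias_fraction)
{
    constexpr float t_min = 0.f, t_max = 255.f; // QASYMM8 range

    // Uniformly distributed 8-bit inputs: mean = midpoint, variance = (n^2 - 1) / 12 with n = 256
    const float mean_int = (t_max + t_min) / 2.f;
    const float var_int  = (256.f * 256.f - 1.f) / 12.f;

    // Float-domain statistics of A and B (dequantized uniform variables)
    const float mean_a = ((t_max - a_offset) * a_scale + (t_min - a_offset) * a_scale) / 2.f;
    const float var_a  = ((t_max - t_min) * a_scale) * ((t_max - t_min) * a_scale) / 12.f;
    const float mean_b = ((t_max - b_offset) * b_scale + (t_min - b_offset) * b_scale) / 2.f;
    const float var_b  = ((t_max - t_min) * b_scale) * ((t_max - t_min) * b_scale) / 12.f;

    // Integer accumulation M_int = sum_k(A_k_int * B_k_int). Since A_int and B_int share the same
    // stats, var(A*B) = var^2 + 2 * var * mean^2, summed K times.
    const float mean_m_int = K * mean_int * mean_int;
    const float var_m_int  = K * (var_int * var_int + 2.f * var_int * mean_int * mean_int);
    const float std_m_int  = std::sqrt(var_m_int);

    // Bias D_int is chosen as a fraction of the accumulation; suggest mean +/- 2 sigma as its range
    const float mean_d_int = bias_fraction * mean_m_int;
    const float std_d_int  = bias_fraction * std_m_int;
    const float mean_d     = a_scale * b_scale * mean_d_int;
    const float var_d      = (a_scale * b_scale * std_d_int) * (a_scale * b_scale * std_d_int);

    // Output C = sum_k(A_k * B_k) + D in the float domain
    const float mean_out = K * mean_a * mean_b + mean_d;
    const float std_out  = std::sqrt(K * (var_a * var_b + var_a * mean_b * mean_b + var_b * mean_a * mean_a) + var_d);

    // By the Central Limit Theorem C is roughly normal; map mean +/- 2 sigma onto the 8-bit range
    QHint hint{};
    hint.scale    = 4.f * std_out / (t_max - t_min);
    hint.offset   = static_cast<int32_t>(t_min - (mean_out - 2.f * std_out) / hint.scale);
    hint.bias_min = static_cast<int32_t>(mean_d_int - 2.f * std_d_int);
    hint.bias_max = static_cast<int32_t>(mean_d_int + 2.f * std_d_int);
    return hint;
}

int main()
{
    // Example: both inputs quantized with scale 0.1 and offset 10, K = 64 accumulations, bias at 50%
    const QHint hint = suggest_dst_q(0.1f, 10, 0.1f, 10, 64, 0.5f);
    std::cout << "scale=" << hint.scale << " offset=" << hint.offset
              << " bias=[" << hint.bias_min << ", " << hint.bias_max << "]\n";
}

Under the same inputs and assumptions, the printed scale/offset and bias range should track the values suggested by suggest_mac_dst_q_info_and_bias().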