From a668f9f8a4eab405df0fe8dd58e7d9425bcf9640 Mon Sep 17 00:00:00 2001 From: Jonathan Deakin Date: Wed, 24 Jan 2024 09:15:38 +0000 Subject: Add s8f32 kernels and dynamic QuantizationInfo - Add support for QASYMM8_SIGNED*QASYMM8_SIGNED->F32 in CpuGemmLowpMatrixMultiplyCore - Add s8f32 kernel using existing s8->s32 kernels with a new DequantizeFloat OutputStage; the structure is similar to Requantize32 but the opposite way around. - Add SME s8f32 kernels with integrated support for DequantizeFloat. - Add scale to CpuGemmLowpOffsetContributionKernel. - Add virtual dequantize scale to gemm_common, only implemented for gemm_interleaved. - Update year to 2024 in generate_build_files. - Add dynamic flag to QuantizationInfo which signals to operators that it can change after configuration - Add support for dynamic quantization in NEGEMMLowpMatrixMultiplyCore - Add dynamic quantization fixture by extending GEMMLowpGenericMatrixMultiplyCoreValidationFixture - Add GEMMLowpDequantizedMatrixMultiplyValidationFixture - Store k (number of cols of A) rather than k_offset in the offset contribution kernels so that we can recompute it when the other offsets change. Relates to: ONCPUML-1444 MLINFSW-439 Co-authored-by: Milos Puzovic Co-authored-by: David Mansell Change-Id: I58a3acf2c09289a303e52eea6b336a696a5bc8da Signed-off-by: Jonathan Deakin Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11022 Reviewed-by: Gunes Bayir Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Benchmark: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/quantized.cpp | 60 +++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) (limited to 'src/core/NEON/kernels/arm_gemm/quantized.cpp') diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp index 111d01ed3a..6da9f4be0e 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.cpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm
Limited. + * Copyright (c) 2019, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * *
@@ -1142,6 +1142,86 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
 template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
 template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
 
+/*
+ * Dequantize a width x height block of int32 accumulators into float.
+ *
+ * Each output element is in * qp.scale, plus bias_ptr[col] when bias_ptr is
+ * non-null, plus the value already stored in the output when accumulate is
+ * set, and is finally clamped to the range selected by the activation.
+ *
+ * qp          - supplies the dequantization scale (qp.scale).
+ * width       - number of columns processed in each row.
+ * height      - number of rows processed.
+ * in_ptr      - first input element.
+ * in_stride   - distance, in elements, between consecutive input rows.
+ * out_ptr     - first output element.
+ * out_stride  - distance, in elements, between consecutive output rows.
+ * bias_ptr    - optional bias, indexed by column and shared by every row.
+ * accumulate  - when true, the existing output value is added to the result.
+ * act         - ReLU clamps below at 0, BoundedReLU additionally clamps above
+ *               at act.param1, any other type leaves the result unclamped.
+ */
+void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height,
+                         const int32_t* in_ptr, unsigned int in_stride, float *out_ptr, unsigned int out_stride,
+                         const float* bias_ptr, bool accumulate, const Activation &act)
+{
+    const float32x4_t vscale = vdupq_n_f32(qp.scale);
+    // Start from an unbounded range so the final clamp is a no-op unless the activation narrows it.
+    float maxval = std::numeric_limits<float>::infinity();
+    float minval = -std::numeric_limits<float>::infinity();
+
+    switch(act.type) {
+        default:
+        case Activation::Type::None:
+            break;
+        case Activation::Type::BoundedReLU:
+            maxval = static_cast<float>(act.param1);
+            /* fall through */
+        case Activation::Type::ReLU:
+            minval = 0;
+            break;
+    }
+
+    const float32x4_t vmin = vdupq_n_f32(minval);
+    const float32x4_t vmax = vdupq_n_f32(maxval);
+
+    for(unsigned int row=0; row<height; row++) {
+        auto row_in_ptr = in_ptr + (row * in_stride);
+        auto row_out_ptr = out_ptr + (row * out_stride);
+        unsigned int col=0;
+        // Vector path, four columns at a time. The width check stops the unsigned
+        // (width - 4) below from wrapping when fewer than four columns exist.
+        if (width >= 4) {
+            for(; col <= (width - 4); col+= 4) {
+                const int32x4_t vin = vld1q_s32(row_in_ptr + col);
+                float32x4_t vdeq = vmulq_f32(vcvtq_f32_s32(vin), vscale);
+                if(bias_ptr) {
+                    const float32x4_t bin = vld1q_f32(bias_ptr + col);
+                    vdeq = vaddq_f32(vdeq, bin);
+                }
+                if(accumulate) {
+                    vdeq = vaddq_f32(vdeq, vld1q_f32(row_out_ptr + col));
+                }
+                vdeq = vminq_f32(vmaxq_f32(vdeq, vmin), vmax);
+                vst1q_f32(reinterpret_cast<float *>(row_out_ptr + col), vdeq);
+            }
+        }
+        // left-over elements
+        for(; col < width; ++col) {
+            const int32_t val = *(row_in_ptr + col);
+            float res = static_cast<float>(val * qp.scale);
+            if(bias_ptr) {
+                res += static_cast<float>(*(bias_ptr + col));
+            }
+            if(accumulate) {
+                res += *(row_out_ptr + col);
+            }
+            res = std::min(std::max(res, minval), maxval);
+            *(row_out_ptr + col) = res;
+        }
+    }
+}
+
 } // namespace arm_gemm
 
 #endif // __aarch64__
-- 
cgit v1.2.1