diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp | 30 |
1 file changed, 24 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index fdb4f584d8..d35825c428 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -179,13 +179,18 @@ public:
         return _subgemm->get_B_pretransposed_array_size() + col_sum_size();
     }
 
-    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
-        uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
-        _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride);
+    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+        _col_sums = reinterpret_cast<int32_t *>(in_buffer);
+        col_sums_pretransposed(B, ldb, B_multi_stride);
+    }
 
-        _col_sums = reinterpret_cast<int32_t *>(buffer);
+    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+        assert(!transposed);
 
-        col_sums_pretransposed(B, ldb, B_multi_stride);
+        uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
+        _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride, transposed);
+
+        requantize_bias(buffer, B, ldb, B_multi_stride);
     }
 
     void set_pretransposed_B_data(void *buffer) override {
@@ -198,6 +203,19 @@ public:
         _params.bias = bias;
         _params.bias_multi_stride = bias_multi_stride;
     }
+
+    GemmConfig get_config() override {
+        GemmConfig c = _subgemm->get_config();
+
+        std::string n = "quantize_wrapper[";
+        n.append(c.filter);
+        n.append("]");
+
+        c.method = GemmMethod::QUANTIZE_WRAPPER;
+        c.filter = n;
+
+        return c;
+    }
 };
 
 } // namespace arm_gemm