about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp 30
1 files changed, 24 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index fdb4f584d8..d35825c428 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -179,13 +179,18 @@ public:
return _subgemm->get_B_pretransposed_array_size() + col_sum_size();
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
- uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
- _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride);
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { // New override: column-sum setup split out of pretranspose_B_array so it can be invoked on its own.
+ _col_sums = reinterpret_cast<int32_t *>(in_buffer); // Treat the start of in_buffer as the int32_t column-sum storage.
+ col_sums_pretransposed(B, ldb, B_multi_stride); // Helper defined outside this hunk; presumably fills _col_sums from B - confirm.
+ }
- _col_sums = reinterpret_cast<int32_t *>(buffer);
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { // Signature gains a trailing 'transposed' flag.
+ assert(!transposed); // This wrapper only accepts transposed == false; checked by assert only.
- col_sums_pretransposed(B, ldb, B_multi_stride);
+ uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer); // Integer form of the address so a byte offset can be added.
+ _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride, transposed); // Sub-GEMM data is placed col_sum_size() bytes past the start of buffer.
+
+ requantize_bias(buffer, B, ldb, B_multi_stride); // Column sums use the start of the same buffer, ahead of the sub-GEMM data.
}
void set_pretransposed_B_data(void *buffer) override {
@@ -198,6 +203,19 @@ public:
_params.bias = bias;
_params.bias_multi_stride = bias_multi_stride;
}
+
+ GemmConfig get_config() override { // Reports the wrapped sub-GEMM's configuration, relabelled to show it runs through the quantize wrapper.
+ GemmConfig c = _subgemm->get_config(); // Start from the sub-GEMM's own configuration.
+
+ std::string n = "quantize_wrapper["; // Build "quantize_wrapper[<sub-GEMM filter>]".
+ n.append(c.filter);
+ n.append("]");
+
+ c.method = GemmMethod::QUANTIZE_WRAPPER; // Overwrite the method with the wrapper's own.
+ c.filter = n; // Replace the filter with the bracketed name built above.
+
+ return c;
+ }
};
} // namespace arm_gemm