aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp83
1 files changed, 31 insertions, 52 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index 18f030fec0..d35825c428 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 ARM Limited.
+ * Copyright (c) 2019-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,16 +61,9 @@ private:
}
/* Local working space: We need space for the subgemm output (above) and
- * the row sums. If the GEMM is not pretransposed we need to store the
- * column sums here too. */
+ * the row sums. */
size_t local_working_size() const {
- size_t sz = subgemm_output_size() + row_sum_size();
-
- if (_args._pretransposed_hint) {
- return sz;
- }
-
- return sz + col_sum_size();
+ return subgemm_output_size() + row_sum_size();
}
void set_child_arrays() {
@@ -90,15 +83,6 @@ private:
}
}
- void col_sums_runtime(unsigned int threadid) {
- unsigned int first_col = (threadid * _args._Nsize) / _args._maxthreads;
- unsigned int last_col = ((threadid + 1) * _args._Nsize) / _args._maxthreads;
-
- for (unsigned int multi=0; multi<_args._nmulti; multi++) {
- compute_col_sums(_params, (last_col - first_col), _args._Ksize, this->_Bptr + (multi * this->_B_multi_stride) + first_col, this->_ldb, _col_sums + (multi * _args._Nsize) + first_col, _args._Ksize, multi, first_col);
- }
- }
-
void requantize_runtime(unsigned int threadid) {
unsigned int first_row = (threadid * _args._Msize) / _args._maxthreads;
unsigned int last_row = ((threadid+1) * _args._Msize) / _args._maxthreads;
@@ -115,7 +99,7 @@ private:
_args._Nsize,
this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (first_row * this->_ldc), this->_ldc,
_row_sums + (multi * _args._nbatches * _args._Msize) + (batch * _args._Msize) + first_row,
- _col_sums + (multi * _args._Nsize));
+ _col_sums + (multi * _args._Nsize), 0);
}
}
}
@@ -126,16 +110,12 @@ public:
QuantizeWrapper operator=(const QuantizeWrapper &) = delete;
QuantizeWrapper(const GemmArgs &args, const Requantize32 &qp) : _params(qp), _args(args), _barrier(args._maxthreads) {
- GemmArgs newargs = GemmArgs(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._trA, args._trB, Activation(), args._maxthreads, args._pretransposed_hint, nullptr);
+ GemmArgs newargs = GemmArgs(args._ci, args._Msize, args._Nsize, args._Ksize, args._Ksections, args._nbatches, args._nmulti, args._indirect_input, Activation(), args._maxthreads);
_subgemm = gemm<To, Tgemm>(newargs);
if (_subgemm == nullptr) {
return;
}
-
- if (!_subgemm->B_is_pretransposed()) {
- _args._pretransposed_hint = false;
- }
}
void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
@@ -149,7 +129,7 @@ public:
}
ndrange_t get_window_size() const override {
- return _subgemm->get_window_size();
+ return { _subgemm->get_window_size() };
}
void set_nthreads(int nthreads) override {
@@ -158,12 +138,8 @@ public:
_args._maxthreads = nthreads;
}
- // Execute
- void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) override {
+ void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) override {
_subgemm->execute(work_range, thread_locator, threadid);
- if (!_args._pretransposed_hint) {
- col_sums_runtime(threadid);
- }
_barrier.arrive_and_wait();
@@ -178,7 +154,7 @@ public:
// ptr
// V
- // | subgemm output | row_sums | col_sums (if not pretransposed | subgemm working space |
+ // | subgemm output | row_sums | subgemm working space |
void set_working_space(void *space) override {
uintptr_t space_int = reinterpret_cast<uintptr_t>(space);
@@ -186,16 +162,13 @@ public:
_subgemm->set_working_space(reinterpret_cast<void *>(space_int + local_working_size()));
_row_sums = reinterpret_cast<int32_t *>(space_int + subgemm_output_size());
- if (!_args._pretransposed_hint) {
- _col_sums = reinterpret_cast<int32_t *>(space_int + subgemm_output_size() + row_sum_size());
- }
set_child_arrays();
}
bool B_is_pretransposed() const override {
/* We clear this flag if the subgemm isn't pretransposed, so just return its value */
- return _args._pretransposed_hint;
+ return _subgemm->B_is_pretransposed();
}
bool B_pretranspose_required() const override {
@@ -203,31 +176,24 @@ public:
}
size_t get_B_pretransposed_array_size() const override {
- if (_args._pretransposed_hint) {
- return _subgemm->get_B_pretransposed_array_size() + col_sum_size();
- }
+ return _subgemm->get_B_pretransposed_array_size() + col_sum_size();
+ }
- return 0;
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ _col_sums = reinterpret_cast<int32_t *>(in_buffer);
+ col_sums_pretransposed(B, ldb, B_multi_stride);
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
- if (!_args._pretransposed_hint) {
- return;
- }
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ assert(!transposed);
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
- _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride);
+ _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride, transposed);
- _col_sums = reinterpret_cast<int32_t *>(buffer);
-
- col_sums_pretransposed(B, ldb, B_multi_stride);
+ requantize_bias(buffer, B, ldb, B_multi_stride);
}
void set_pretransposed_B_data(void *buffer) override {
- if (!_args._pretransposed_hint) {
- return;
- }
-
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
_subgemm->set_pretransposed_B_data(reinterpret_cast<void *>(buffer_int + col_sum_size()));
_col_sums = reinterpret_cast<int32_t *>(buffer);
@@ -237,6 +203,19 @@ public:
_params.bias = bias;
_params.bias_multi_stride = bias_multi_stride;
}
+
+ GemmConfig get_config() override {
+ GemmConfig c = _subgemm->get_config();
+
+ std::string n = "quantize_wrapper[";
+ n.append(c.filter);
+ n.append("]");
+
+ c.method = GemmMethod::QUANTIZE_WRAPPER;
+ c.filter = n;
+
+ return c;
+ }
};
} // namespace arm_gemm