about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_batched.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemv_batched.hpp 26
1 files changed, 18 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index 939788ed8d..ad504f2664 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 ARM Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,17 +45,15 @@ public:
_subgemm = gemm<To,Tr>(newargs);
}
- void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+ void set_arrays(const To *A, const int, const int A_batch_stride, const int A_multi_stride,
const To *B, const int ldb, const int B_multi_stride,
- Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+ Tr *C, const int, const int C_batch_stride, const int C_multi_stride,
const Tr *bias, const int bias_multi_stride) override {
/* A and C's batch stride becomes their new row stride. New batch stride is 0 as nbatches for subgemm is always 1. */
_subgemm->set_arrays(A, A_batch_stride, 0, A_multi_stride,
B, ldb, B_multi_stride,
C, C_batch_stride, 0, C_multi_stride,
bias, bias_multi_stride);
- UNUSED(lda);
- UNUSED(ldc);
}
ndrange_t get_window_size() const override {
@@ -66,7 +64,7 @@ public:
_subgemm->set_nthreads(nthreads);
}
- void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) override {
+ void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) override {
_subgemm->execute(work_range, thread_locator, threadid);
}
@@ -90,13 +88,25 @@ public:
return _subgemm->get_B_pretransposed_array_size();
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
- _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride);
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride, transposed);
}
void set_pretransposed_B_data(void *buffer) override {
_subgemm->set_pretransposed_B_data(buffer);
}
+
+ GemmConfig get_config() override {
+ GemmConfig c = _subgemm->get_config();
+
+ std::string new_filter = "gemv_batched[";
+ new_filter.append(c.filter);
+ new_filter.append("]");
+
+ c.filter = new_filter;
+
+ return c;
+ }
};
} // namespace arm_gemm