aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/misc.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/misc.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/misc.cpp48
1 files changed, 46 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/misc.cpp b/src/core/NEON/kernels/arm_gemm/misc.cpp
index 229e6b56f9..87310d996d 100644
--- a/src/core/NEON/kernels/arm_gemm/misc.cpp
+++ b/src/core/NEON/kernels/arm_gemm/misc.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2017-2018, 2022-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,11 @@
#ifndef NO_MULTI_THREADING
#include <mutex>
#endif
+#include <cstdint>
+
+#include "arm_gemm.hpp"
+#include "kernel_weight_format.hpp"
+#include "utils.hpp"
namespace arm_gemm {
@@ -32,4 +37,43 @@ namespace arm_gemm {
std::mutex report_mutex;
#endif
-} // namespace arm_gemm \ No newline at end of file
+WeightFormat get_weight_format(const KernelWeightFormat kwf, size_t element_size) {
+ if (kwf==KernelWeightFormat::NON_FIXED) {
+ return WeightFormat::UNSPECIFIED;
+ }
+
+ uint32_t kwf_i = static_cast<uint32_t>(kwf);
+ uint32_t wf_i = 0;
+
+ const auto block_bytes = (kwf_i >> 8) & 0xf;
+ const auto vector_count = (kwf_i >> 12) & 0xf;
+
+ uint32_t vector_bytes;
+
+ // For fast mode BF16 kernels set the appropriate bit and override element size to 2.
+ if (kwf_i & 0x10) {
+ element_size = 2;
+ wf_i |= 0x10;
+ }
+
+#ifdef ARM_COMPUTE_ENABLE_SVE
+ // Get total bytes in vector output
+ if (kwf_i & 0x1) {
+ vector_bytes = vector_count * get_vector_length<uint8_t>();
+ } else {
+#else
+ if (1) {
+#endif
+ vector_bytes = vector_count * 16;
+ }
+
+ auto input_blocking = block_bytes / element_size;
+ auto output_blocking = vector_bytes / block_bytes;
+
+ wf_i |= (input_blocking << 20);
+ wf_i |= (output_blocking << 8);
+
+ return static_cast<WeightFormat>(wf_i);
+}
+
+} // namespace arm_gemm