Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/interleave-8way.cpp')
 -rw-r--r--  src/core/NEON/kernels/arm_gemm/interleave-8way.cpp  267
 1 file changed, 267 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp b/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp
new file mode 100644
index 0000000000..a05d700c5e
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp
@@ -0,0 +1,267 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include <arm_neon.h>
+
+#if !defined(_WIN64) && !defined(__OpenBSD__)
+#include <alloca.h>
+#endif /* !defined(_WIN64) && !defined(__OpenBSD__) */
+
+#include <cstring>
+
+#include "transform.hpp"
+#include "utils.hpp"
+
+namespace arm_gemm {
+
+namespace {
+
+// Helper functions to interleave a single 4x4 block of 32-bit values
+// together.
+
+// _full version doesn't need to worry about any padding.
+static inline void transpose_block_32_full(const uint8_t * __restrict in_ptr0, const uint8_t * __restrict in_ptr1, const uint8_t * __restrict in_ptr2, const uint8_t * __restrict in_ptr3, uint8_t * __restrict out_ptr, long output_stride) {
+ uint32x4_t inputs[4];
+ uint32x4_t inters[4];
+ uint32x4_t outputs[4];
+
+ inputs[0] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr0));
+ inputs[1] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr1));
+ inputs[2] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr2));
+ inputs[3] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr3));
+
+ inters[0] = vzip1q_u32(inputs[0], inputs[2]);
+ inters[1] = vzip2q_u32(inputs[0], inputs[2]);
+ inters[2] = vzip1q_u32(inputs[1], inputs[3]);
+ inters[3] = vzip2q_u32(inputs[1], inputs[3]);
+
+ outputs[0] = vzip1q_u32(inters[0], inters[2]);
+ outputs[1] = vzip2q_u32(inters[0], inters[2]);
+ outputs[2] = vzip1q_u32(inters[1], inters[3]);
+ outputs[3] = vzip2q_u32(inters[1], inters[3]);
+
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr), outputs[0]);
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride), outputs[1]);
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*2), outputs[2]);
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*3), outputs[3]);
+}
+
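+// For reference, the three vzip stages in transpose_block_32_full() amount
+// to a plain 4x4 transpose of 32-bit elements: output block r receives
+// 4-byte element r from each of the four input rows.  An illustrative
+// scalar sketch of the same operation (not used by the kernels):
+static inline void transpose_block_32_scalar_reference(const uint8_t *in_ptr0, const uint8_t *in_ptr1, const uint8_t *in_ptr2, const uint8_t *in_ptr3, uint8_t *out_ptr, long output_stride) {
+    const uint8_t *in_ptrs[4] = { in_ptr0, in_ptr1, in_ptr2, in_ptr3 };
+
+    for (long r=0; r<4; r++) {
+        for (long c=0; c<4; c++) {
+            // 4-byte element r of input row c lands at position c of output block r.
+            memcpy(out_ptr + (r * output_stride) + (c * 4), in_ptrs[c] + (r * 4), 4);
+        }
+    }
+}
+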
+// _part version: Only read "bytes_in" bytes, not a full vector.  Only write
+// out 4-byte blocks that have some live content (if bytes_in is not a
+// multiple of 4 there will be some padding in the final 4-byte block of
+// each row).
+static inline void transpose_block_32_part(const uint8_t *in_ptr0, const uint8_t *in_ptr1, const uint8_t *in_ptr2, const uint8_t *in_ptr3, uint8_t *out_ptr, long bytes_in, long output_stride) {
+ uint32x4_t inputs[4];
+ uint32x4_t inters[4];
+ uint32x4_t outputs[4];
+ uint8_t scratch[16] = {0};
+
+ long num_outs = iceildiv<long>(bytes_in, 4);
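+    // For example, bytes_in == 10 gives num_outs == 3: two full 4-byte
+    // blocks plus one block holding 2 live bytes and 2 bytes of zero padding.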
+
+ memcpy(scratch, in_ptr0, bytes_in);
+ inputs[0] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+ memcpy(scratch, in_ptr1, bytes_in);
+ inputs[1] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+ memcpy(scratch, in_ptr2, bytes_in);
+ inputs[2] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+ memcpy(scratch, in_ptr3, bytes_in);
+ inputs[3] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+
+ inters[0] = vzip1q_u32(inputs[0], inputs[2]);
+ inters[1] = vzip2q_u32(inputs[0], inputs[2]);
+ inters[2] = vzip1q_u32(inputs[1], inputs[3]);
+ inters[3] = vzip2q_u32(inputs[1], inputs[3]);
+
+ outputs[0] = vzip1q_u32(inters[0], inters[2]);
+ outputs[1] = vzip2q_u32(inters[0], inters[2]);
+ outputs[2] = vzip1q_u32(inters[1], inters[3]);
+ outputs[3] = vzip2q_u32(inters[1], inters[3]);
+
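+    // The do { } while (0) below is simply a breakable block: store each
+    // live 4-byte output block in turn, breaking out after the last one.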
+ do {
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr), outputs[0]);
+ if (num_outs < 2)
+ break;
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride), outputs[1]);
+ if (num_outs < 3)
+ break;
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*2), outputs[2]);
+ if (num_outs < 4)
+ break;
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*3), outputs[3]);
+ } while (0);
+}
+
+template<unsigned N>
+struct Unroll {
+ template<typename F>
+ static void run(F f) {
+ Unroll<N-1>::run(f);
+ f(N-1);
+ }
+};
+
+template<>
+struct Unroll<0> {
+ template<typename F>
+ static void run(F) {
+ }
+};
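+
+// Illustrative note: Unroll<N>::run(f) recurses at compile time, expanding
+// to the straight-line sequence f(0); f(1); ... f(N-1).  For example,
+// Unroll<3>::run(f) becomes f(0); f(1); f(2);.  This lets the in_ptrs[]
+// loops below be fully unrolled, keeping the pointers in registers.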
+
+// Interleave some multiple of 4 rows together.
+//
+// The template parameter BLOCKS controls the size of the inner loop - each BLOCK is 4 rows.
+// The function parameter interleave_multiple controls the number of times the inner loop is run.
+
+// The total interleave depth for a given run is therefore BLOCKS * interleave_multiple * 4.
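+//
+// For example, the <24, 1> float specialisation below uses BLOCKS=3 and
+// interleave_multiple=2 for a total depth of 3 * 4 * 2 = 24.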
+template<unsigned BLOCKS>
+void a64_interleave_1x4(uint8_t *out, const uint8_t *in, long width, long in_stride, long height, long interleave_multiple) {
+ const long total_interleave_depth = BLOCKS * 4 * interleave_multiple;
+ constexpr long loop_interleave_depth = BLOCKS * 4;
+
+ uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width));
+
+ if (height % total_interleave_depth) {
+ memset(pad_row, 0, width);
+ }
+
+ // Outer loop: process blocks of total_interleave_depth rows at a time.
+ for (long y0_base=0; y0_base<height; y0_base+=total_interleave_depth) {
+        // Middle loop: process each of the "interleave_multiple" blocks of rows.
+ for (long block=0; block<interleave_multiple; block++) {
+ const long y0 = y0_base + (block * loop_interleave_depth);
+ uint8_t *out_ptr = out + (block * loop_interleave_depth * 4); // 4 is the blocking depth (we interleave 4 bytes at a time from each input)
+
+ // Create and set up input row pointers. The idea is that these
+ // should entirely fit in the register file, so we don't have to
+            // repeatedly load them (or perform the padding check).
+ const uint8_t *in_ptrs[loop_interleave_depth];
+ Unroll<loop_interleave_depth>::run( [&](unsigned y) {
+ in_ptrs[y] = (y+y0 < height) ? in + ((y+y0) * in_stride) : pad_row;
+ });
+
+ long bytes_left = width;
+ // Process full vectors using transpose_block_32_full()
+ while (bytes_left >= 16) { // 16 is the vector length in bytes
+ Unroll<BLOCKS>::run( [&](unsigned u) {
+ transpose_block_32_full(in_ptrs[u*4 + 0], in_ptrs[u*4 + 1], in_ptrs[u*4 + 2], in_ptrs[u*4 + 3],
+ out_ptr + 16*u, total_interleave_depth * 4); // 4 is the blocking depth
+ });
+
+ Unroll<loop_interleave_depth>::run( [&](unsigned y) {
+ in_ptrs[y] += 16; // 16 is the vector length in bytes
+ });
+
+ out_ptr += total_interleave_depth * 16; // 16 is the vector length in bytes
+ bytes_left -= 16; // 16 is the vector length in bytes
+ }
+
+ // Process any remaining bytes using transpose_block_32_part()
+ if (bytes_left) {
+ Unroll<BLOCKS>::run( [&](unsigned u) {
+ transpose_block_32_part(in_ptrs[u*4 + 0], in_ptrs[u*4 + 1], in_ptrs[u*4 + 2], in_ptrs[u*4 + 3],
+ out_ptr + 16*u, bytes_left, total_interleave_depth * 4);
+ });
+ }
+ }
+
+ // Update "out" pointer for next set of total_interleave_depth rows
+ out += total_interleave_depth * roundup<long>(width, 4);
+ }
+}
+
+} // anonymous namespace
+
+template<>
+void Transform<16, 4, false, VLType::None>(
+ uint8_t *out, const uint8_t *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<4>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0),
+ stride,
+ (ymax - y0),
+ 1
+ );
+}
+
+template<>
+void Transform<16, 4, false, VLType::None>(
+ int8_t *out, const int8_t *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<4>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0),
+ stride,
+ (ymax - y0),
+ 1
+ );
+}
+
+template<>
+void Transform<12, 1, false, VLType::None>(
+ float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<3>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0) * sizeof(float),
+ stride * sizeof(float),
+ (ymax - y0),
+ 1
+ );
+}
+
+template<>
+void Transform<16, 1, false, VLType::None>(
+ float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<4>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0) * sizeof(float),
+ stride * sizeof(float),
+ (ymax - y0),
+ 1
+ );
+}
+
+template<>
+void Transform<24, 1, false, VLType::None>(
+ float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<3>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0) * sizeof(float),
+ stride * sizeof(float),
+ (ymax - y0),
+ 2
+ );
+}
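+
+// Usage sketch (illustrative, not part of the original file): packing a
+// hypothetical 100x100 uint8_t matrix "src" with row stride 100 into "dst":
+//
+//   Transform<16, 4, false, VLType::None>(dst, src, 100, 0, 100, 0, 100);
+//
+// This interleaves the rows in groups of 16 (zero-padding the final group),
+// taking 4 bytes at a time from each row.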
+
+} // namespace arm_gemm
+
+#endif // __aarch64__