aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp13
-rw-r--r--src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp14
-rw-r--r--src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp12
-rw-r--r--src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp14
4 files changed, 46 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
index a460fdfcf4..543664bb0e 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@ template<typename T>
inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);
+ bool first = true;
uint32_t zerobuff[16] = { 0 }; // 8 for asm loop plus up to 7 for overflow loop
@@ -53,8 +54,9 @@ inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T *
//prefetch_2x(inptr5);
int x=(kmax-k0);
- for (;x>7;x-=8) {
+ for (;(x>7) || first;x-=8) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
+ /* 'first' forces this to always run at least once, needed if the total size is <=7. */
if ((y + 5) >= ymax) {
switch ((y + 5) - ymax) {
/* Everything falls through in here */
@@ -79,6 +81,13 @@ inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T *
}
}
+ if (first) {
+ if (x<=7) {
+ break;
+ }
+
+ first = false;
+ }
__asm __volatile (
// Load up 8 elements (2 vectors) from each of 8 sources.
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
index 0028ab08a9..80dd6c5e25 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@ template<typename T>
void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
uint16_t *outptr = (uint16_t *)out;
const uint16_t *inptr = (const uint16_t *)in;
+ bool first=true;
uint16_t zerobuff[16] = { 0 }; // 8 for asm loop plus up to 7 for overflow loop
@@ -57,8 +58,9 @@ void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int
prefetch_2x(inptr7);
int x=(kmax-k0);
- for (;x>7;x-=8) {
+ for (;(x>7) || first;x-=8) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
+ /* 'first' forces this to always run at least once, needed if the total size is <=7. */
if ((y + 7) >= ymax) {
switch ((y + 7) - ymax) {
/* Everything falls through in here */
@@ -89,6 +91,14 @@ void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int
}
}
+ if (first) {
+ if (x <= 7) {
+ break;
+ }
+
+ first = false;
+ }
+
int skippf = (x & 31);
__asm __volatile (
// Load up 8 elements (1 vector) from each of 8 sources.
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
index 758c084a46..9dfc1346e6 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
@@ -34,6 +34,7 @@ template<typename T>
inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
uint32_t *outptr = (uint32_t *)out;
const uint32_t *inptr = (uint32_t *)in;
+ bool first = true;
uint32_t zerobuff[16] = { 0 }; // 8 for asm loop plus up to 7 for overflow loop
@@ -57,8 +58,9 @@ inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *
prefetch_2x(inptr7);
int x=(kmax-k0);
- for (;x>7;x-=8) {
+ for (;(x>7) || first;x-=8) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
+ /* 'first' forces this to always run at least once, needed if the total size is <=7. */
if ((y + 7) >= ymax) {
switch ((y + 7) - ymax) {
/* Everything falls through in here */
@@ -89,6 +91,14 @@ inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *
}
}
+ if (first) {
+ if (x<=7) {
+ break;
+ }
+
+ first = false;
+ }
+
__asm __volatile (
// Load up 8 elements (2 vectors) from each of 8 sources.
"LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
index de8e95a6d7..bde3274926 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@ template<>
inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const __fp16 *in, int ldin, int y0, int ymax, int k0, int kmax) {
float *outptr = out;
const __fp16 *inptr = in;
+ bool first = true;
__fp16 zerobuff[16] = { 0 }; // 8 for asm loop plus up to 7 for overflow loop
@@ -57,8 +58,9 @@ inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const
prefetch_2x(inptr7);
int x=(kmax-k0);
- for (;x>7;x-=8) {
+ for (;(x>7) || first;x-=8) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
+ /* 'first' forces this to always run at least once, needed if the total size is <=7. */
if ((y + 7) >= ymax) {
switch ((y + 7) - ymax) {
/* Everything falls through in here */
@@ -89,6 +91,14 @@ inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const
}
}
+ if (first) {
+ if (x<=7) {
+ break;
+ }
+
+ first = false;
+ }
+
__asm __volatile (
// Load up 8 elements (2 vectors) from each of 8 sources.
"LDR q0, [%[inptr0]], #16\n"