aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp
diff options
context:
space:
mode:
author: Anthony Barbier <anthony.barbier@arm.com> 2018-07-03 16:22:02 +0100
committer: Anthony Barbier <anthony.barbier@arm.com> 2018-11-02 16:54:10 +0000
commit: 5f707736413aeac77818c42838296966f8dc6761 (patch)
tree: b829ed3243ea5f3085f288836132416c78bc2e72 /src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp
parent: 7485d5a62685cb745ab50e970adb722cb71557ac (diff)
download: ComputeLibrary-5f707736413aeac77818c42838296966f8dc6761.tar.gz
COMPMID-1369: Revert accidental formatting of RSH's repo
Pulled latest fixes from David's repo:

commit f43ebe932c84083332b0b1a0348241b69dda63a7
Author: David Mansell <David.Mansell@arm.com>
Date: Tue Jul 3 18:09:01 2018 +0100

Whitespace tidying, fixed comment in gemv_batched imported from ACL.

Change-Id: Ie37a623f44e90d88072236cb853ac55ac82d5f51
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/138530
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: David Mansell <david.mansell@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp 218
1 files changed, 99 insertions, 119 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp b/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp
index 3218ca1aac..63e85c155a 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp
@@ -24,137 +24,117 @@
#pragma once
template <unsigned int IntBy, typename TIn, typename TOut>
-struct TransposeInterleaveCommon
-{
- // Override the moveblock_1xY methods to improve performance
- static inline void moveblock_1x1(const TIn *&in0, TOut *out)
- {
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in0++);
- }
+struct TransposeInterleaveCommon {
+ // Override the moveblock_1xY methods to improve performance
+ static inline void moveblock_1x1(const TIn *&in0, TOut *out) {
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in0++);
}
+ }
- static inline void moveblock_1x2(const TIn *&in0, const TIn *&in1, TOut *out)
- {
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in0++);
- }
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in1++);
- }
+ static inline void moveblock_1x2(const TIn *&in0, const TIn *&in1, TOut *out) {
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in0++);
+ }
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in1++);
}
+ }
- static inline void moveblock_1x4(const TIn *&in0, const TIn *&in1, const TIn *&in2, const TIn *&in3, TOut *out)
- {
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in0++);
- }
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in1++);
- }
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in2++);
- }
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in3++);
+ static inline void moveblock_1x4(const TIn *&in0, const TIn *&in1, const TIn *&in2, const TIn *&in3, TOut *out) {
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in0++);
+ }
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in1++);
+ }
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in2++);
+ }
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in3++);
+ }
+ }
+
+ static inline void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax) {
+ const auto ldin = stride;
+
+ TOut *outarray = out;
+ const TIn *inarray = in;
+ TOut *outptr_base = outarray;
+ const TIn *inptr_base = inarray + x0 + (k0 * ldin);
+ int ldout = (kmax - k0) * IntBy;
+
+ int k=(kmax-k0);
+ for ( ; k>3; k-=4) {
+ TOut *outptr = outptr_base;
+ const TIn *inptr = inptr_base;
+ const TIn *inptr1 = inptr + ldin;
+ const TIn *inptr2 = inptr1 + ldin;
+ const TIn *inptr3 = inptr2 + ldin;
+
+ prefetch_3x(inptr);
+ prefetch_3x(inptr1);
+ prefetch_3x(inptr2);
+ prefetch_3x(inptr3);
+
+ outptr_base += IntBy * 4;
+ inptr_base += ldin * 4;
+
+ for (int x = (xmax-x0) / IntBy; x > 0 ; x--) {
+ moveblock_1x4(inptr, inptr1, inptr2, inptr3, outptr);
+ outptr += ldout;
}
}
- static inline void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax)
- {
- const auto ldin = stride;
-
- TOut *outarray = out;
- const TIn *inarray = in;
- TOut *outptr_base = outarray;
- const TIn *inptr_base = inarray + x0 + (k0 * ldin);
- int ldout = (kmax - k0) * IntBy;
-
- int k = (kmax - k0);
- for(; k > 3; k -= 4)
- {
- TOut *outptr = outptr_base;
- const TIn *inptr = inptr_base;
- const TIn *inptr1 = inptr + ldin;
- const TIn *inptr2 = inptr1 + ldin;
- const TIn *inptr3 = inptr2 + ldin;
-
- prefetch_3x(inptr);
- prefetch_3x(inptr1);
- prefetch_3x(inptr2);
- prefetch_3x(inptr3);
-
- outptr_base += IntBy * 4;
- inptr_base += ldin * 4;
-
- for(int x = (xmax - x0) / IntBy; x > 0; x--)
- {
- moveblock_1x4(inptr, inptr1, inptr2, inptr3, outptr);
- outptr += ldout;
+ if (k) {
+ TOut *outptr = outptr_base;
+ const TIn *inptr = inptr_base;
+ const TIn *inptr1 = inptr + ldin;
+ const TIn *inptr2 = inptr1 + ldin;
+
+ prefetch_3x(inptr);
+ prefetch_3x(inptr1);
+ prefetch_3x(inptr2);
+
+ for (int x = (xmax-x0) / IntBy; x > 0 ; x--) {
+ switch(k) {
+ case 3:
+ moveblock_1x2(inptr, inptr1, outptr);
+ moveblock_1x1(inptr2, outptr + IntBy * 2);
+ break;
+
+ case 2:
+ moveblock_1x2(inptr, inptr1, outptr);
+ break;
+
+ case 1:
+ moveblock_1x1(inptr, outptr);
+ break;
+
+ default:
+ UNREACHABLE("Impossible.");
}
- }
- if(k)
- {
- TOut *outptr = outptr_base;
- const TIn *inptr = inptr_base;
- const TIn *inptr1 = inptr + ldin;
- const TIn *inptr2 = inptr1 + ldin;
-
- prefetch_3x(inptr);
- prefetch_3x(inptr1);
- prefetch_3x(inptr2);
-
- for(int x = (xmax - x0) / IntBy; x > 0; x--)
- {
- switch(k)
- {
- case 3:
- moveblock_1x2(inptr, inptr1, outptr);
- moveblock_1x1(inptr2, outptr + IntBy * 2);
- break;
-
- case 2:
- moveblock_1x2(inptr, inptr1, outptr);
- break;
-
- case 1:
- moveblock_1x1(inptr, outptr);
- break;
-
- default:
- UNREACHABLE("Impossible.");
- }
-
- outptr += ldout;
- }
+ outptr += ldout;
}
+ }
+
+ // Cope with ragged X cases
+ const unsigned int overflow = (xmax - x0) % IntBy;
+ if (overflow) {
+ const TIn *inptr_base = inarray + (xmax - overflow) + (k0 * ldin);
+ TOut *outptr = outarray + ((xmax - x0) / IntBy) * ldout;
+
+ for (int k=(kmax-k0); k>0; k--) {
+ const TIn *inptr = inptr_base;
+ inptr_base += ldin;
- // Cope with ragged X cases
- const unsigned int overflow = (xmax - x0) % IntBy;
- if(overflow)
- {
- const TIn *inptr_base = inarray + (xmax - overflow) + (k0 * ldin);
- TOut *outptr = outarray + ((xmax - x0) / IntBy) * ldout;
-
- for(int k = (kmax - k0); k > 0; k--)
- {
- const TIn *inptr = inptr_base;
- inptr_base += ldin;
-
- for(unsigned int x = 0; x < IntBy; x++)
- {
- TOut val = (x < overflow) ? static_cast<TOut>(*inptr++) : static_cast<TOut>(0);
- *outptr++ = val;
- }
+ for (unsigned int x=0; x < IntBy; x++) {
+ TOut val = (x < overflow) ? static_cast<TOut>(*inptr++) : static_cast<TOut>(0);
+ *outptr++ = val;
}
}
}
+}
};