diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp | 226 |
1 file changed, 141 insertions, 85 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp index e1af2d4490..9409646818 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp @@ -28,13 +28,35 @@ #include <arm_neon.h> template<> -inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) { +void MergeResults<8, 6, false>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float *bias, Activation act, bool append) { const float *inptr = in; prefetch_6x(inptr); prefetch_6x(inptr + 96); - float32x4_t av = vdupq_n_f32(alpha); - float32x4_t bv = vdupq_n_f32(beta); + float nullbias[8]; + float minval = - std::numeric_limits<float>::infinity(); + float maxval = std::numeric_limits<float>::infinity(); + + switch(act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + minval = 0.0f; + break; + } + + float32x4_t minv = vdupq_n_f32(minval); + float32x4_t maxv = vdupq_n_f32(maxval); + + if (!append && !bias) + { + memset(nullbias, 0, (8 * sizeof(float))); + } for (int y=y0; y<ymax; y+=8) { float *outptr0 = out + (y * ldout) + x0; @@ -61,16 +83,12 @@ inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldo switch ((y + 5) - ymax) { case 4: outptr1 = dummyres; - // fall through case 3: outptr2 = dummyres; - // fall through case 2: outptr3 = dummyres; - // fall through case 1: outptr4 = dummyres; - // fall through case 0: outptr5 = dummyres; break; @@ -80,24 +98,24 @@ inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldo } } - if (beta == 0.0f) { - /* If beta=0, don't read the 
original input at all. */ + if (append) { + /* Append mode: Read, activate, write. */ /* For ragged X, manually copy over the valid results. */ if ((i+7) >= xmax) { for (int xi=0; xi<8; xi++) { if ((i+xi) < xmax) { - *outptr0 = (alpha * inptr[xi]); + *outptr0 = std::min(std::max(minval, inptr[xi] + *outptr0), maxval); outptr0++; - *outptr1 = (alpha * inptr[xi + 8]); + *outptr1 = std::min(std::max(minval, inptr[xi + 8] + *outptr1), maxval); outptr1++; - *outptr2 = (alpha * inptr[xi + 16]); + *outptr2 = std::min(std::max(minval, inptr[xi + 16] + *outptr2), maxval); outptr2++; - *outptr3 = (alpha * inptr[xi + 24]); + *outptr3 = std::min(std::max(minval, inptr[xi + 24] + *outptr3), maxval); outptr3++; - *outptr4 = (alpha * inptr[xi + 32]); + *outptr4 = std::min(std::max(minval, inptr[xi + 32] + *outptr4), maxval); outptr4++; - *outptr5 = (alpha * inptr[xi + 40]); + *outptr5 = std::min(std::max(minval, inptr[xi + 40] + *outptr5), maxval); outptr5++; } } @@ -107,69 +125,100 @@ inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldo __asm __volatile ( // Rows 0-1 "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VLD1.32 {d8-d11}, [%[outptr0]]\n" "VLD1.32 {d4-d7}, [%[inptr]]!\n" + "VLD1.32 {d12-d15}, [%[outptr1]]\n" - "VMUL.f32 q4, q0, %q[av]\n" + "VADD.f32 q4, q4, q0\n" ASM_PREFETCH("[%[inptr], #352]") - "VMUL.f32 q5, q1, %q[av]\n" - "VST1.32 {d8-d11}, [%[outptr0]]!\n" + "VADD.f32 q5, q5, q1\n" + "VADD.f32 q6, q6, q2\n" + "VADD.f32 q7, q7, q3\n" ASM_PREFETCH("[%[inptr], #416]") - "VMUL.f32 q6, q2, %q[av]\n" + "VMAX.f32 q4, q4, %q[minv]\n" + "VMAX.f32 q5, q5, %q[minv]\n" + "VMAX.f32 q6, q6, %q[minv]\n" ASM_PREFETCH("[%[inptr], #480]") - "VMUL.f32 q7, q3, %q[av]\n" + "VMAX.f32 q7, q7, %q[minv]\n" + "VMIN.f32 q4, q4, %q[maxv]\n" + "VMIN.f32 q5, q5, %q[maxv]\n" + "VST1.32 {d8-d11}, [%[outptr0]]!\n" + "VMIN.f32 q6, q6, %q[maxv]\n" + "VMIN.f32 q7, q7, %q[maxv]\n" "VST1.32 {d12-d15}, [%[outptr1]]!\n" // Rows 2-3 "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VLD1.32 {d8-d11}, 
[%[outptr2]]\n" "VLD1.32 {d4-d7}, [%[inptr]]!\n" + "VLD1.32 {d12-d15}, [%[outptr3]]\n" - "VMUL.f32 q4, q0, %q[av]\n" + "VADD.f32 q4, q4, q0\n" ASM_PREFETCH("[%[outptr0], #96]") - "VMUL.f32 q5, q1, %q[av]\n" - "VST1.32 {d8-d11}, [%[outptr2]]!\n" + "VADD.f32 q5, q5, q1\n" + "VADD.f32 q6, q6, q2\n" + "VADD.f32 q7, q7, q3\n" ASM_PREFETCH("[%[outptr1], #96]") - "VMUL.f32 q6, q2, %q[av]\n" - ASM_PREFETCH("[%[outptr2], #96]") - "VMUL.f32 q7, q3, %q[av]\n" + "VMAX.f32 q4, q4, %q[minv]\n" + "VMAX.f32 q5, q5, %q[minv]\n" + "VMAX.f32 q6, q6, %q[minv]\n" + ASM_PREFETCH("[%[outptr2], #128]") + "VMAX.f32 q7, q7, %q[minv]\n" + "VMIN.f32 q4, q4, %q[maxv]\n" + "VMIN.f32 q5, q5, %q[maxv]\n" + "VST1.32 {d8-d11}, [%[outptr2]]!\n" + "VMIN.f32 q6, q6, %q[maxv]\n" + "VMIN.f32 q7, q7, %q[maxv]\n" "VST1.32 {d12-d15}, [%[outptr3]]!\n" // Rows 4-5 "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VLD1.32 {d8-d11}, [%[outptr4]]\n" "VLD1.32 {d4-d7}, [%[inptr]]!\n" + "VLD1.32 {d12-d15}, [%[outptr5]]\n" - "VMUL.f32 q4, q0, %q[av]\n" + "VADD.f32 q4, q4, q0\n" ASM_PREFETCH("[%[outptr3], #96]") - "VMUL.f32 q5, q1, %q[av]\n" - "VST1.32 {d8-d11}, [%[outptr4]]!\n" - ASM_PREFETCH("[%[outptr4], #96]") - "VMUL.f32 q6, q2, %q[av]\n" + "VADD.f32 q5, q5, q1\n" + "VADD.f32 q6, q6, q2\n" + "VADD.f32 q7, q7, q3\n" + ASM_PREFETCH("[%[outptr4], #128]") + "VMAX.f32 q4, q4, %q[minv]\n" + "VMAX.f32 q5, q5, %q[minv]\n" + "VMAX.f32 q6, q6, %q[minv]\n" ASM_PREFETCH("[%[outptr5], #128]") - "VMUL.f32 q7, q3, %q[av]\n" + "VMAX.f32 q7, q7, %q[minv]\n" + "VMIN.f32 q4, q4, %q[maxv]\n" + "VMIN.f32 q5, q5, %q[maxv]\n" + "VST1.32 {d8-d11}, [%[outptr4]]!\n" + "VMIN.f32 q6, q6, %q[maxv]\n" + "VMIN.f32 q7, q7, %q[maxv]\n" "VST1.32 {d12-d15}, [%[outptr5]]!\n" : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [inptr] "+r" (inptr) - : [av] "w" (av), [bv] "w" (bv) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7" + : [minv] "w" (minv), 
[maxv] "w" (maxv) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "memory" ); } } else { - /* Non-zero beta: Read output and apply beta. */ + /* Bias mode: Add bias to everything, then min/max/write as before. */ + const float *biasptr = bias ? bias + i : nullbias; /* For ragged X, manually copy over the valid results. */ if ((i+7) >= xmax) { - for (int xi=0; xi<8; xi++) { + for (int xi=0; xi<8; xi++) { if ((i+xi) < xmax) { - *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); + *outptr0 = std::min(std::max(minval, inptr[xi] + biasptr[xi]), maxval); outptr0++; - *outptr1 = (alpha * inptr[xi + 8]) + (*outptr1 * beta); + *outptr1 = std::min(std::max(minval, inptr[xi + 8] + biasptr[xi]), maxval); outptr1++; - *outptr2 = (alpha * inptr[xi + 16]) + (*outptr2 * beta); + *outptr2 = std::min(std::max(minval, inptr[xi + 16] + biasptr[xi]), maxval); outptr2++; - *outptr3 = (alpha * inptr[xi + 24]) + (*outptr3 * beta); + *outptr3 = std::min(std::max(minval, inptr[xi + 24] + biasptr[xi]), maxval); outptr3++; - *outptr4 = (alpha * inptr[xi + 32]) + (*outptr4 * beta); + *outptr4 = std::min(std::max(minval, inptr[xi + 32] + biasptr[xi]), maxval); outptr4++; - *outptr5 = (alpha * inptr[xi + 40]) + (*outptr5 * beta); + *outptr5 = std::min(std::max(minval, inptr[xi + 40] + biasptr[xi]), maxval); outptr5++; } } @@ -178,68 +227,75 @@ inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldo /* Optimized routine to copy an entire block */ __asm __volatile ( // Rows 0-1 - "VLD1.32 {d8-d11}, [%[outptr0]]\n" - "VMUL.f32 q4, q4, %q[bv]\n" - "VLD1.32 {d12-d15}, [%[outptr1]]\n" - "VMUL.f32 q5, q5, %q[bv]\n" - "VLD1.32 {d0-d3}, [%[inptr]]!\n" - "VMUL.f32 q6, q6, %q[bv]\n" - "VLD1.32 {d4-d7}, [%[inptr]]!\n" - "VMUL.f32 q7, q7, %q[bv]\n" + "VLD1.32 {d8-d11}, [%[inptr]]!\n" + "VLD1.32 {d0-d3}, [%[biasptr]]\n" + "VLD1.32 {d12-d15}, [%[inptr]]!\n" - "VMLA.f32 q4, q0, %q[av]\n" + "VADD.f32 q4, q4, q0\n" ASM_PREFETCH("[%[inptr], #352]") - "VMLA.f32 q5, q1, %q[av]\n" - 
"VST1.32 {d8-d11}, [%[outptr0]]!\n" + "VADD.f32 q5, q5, q1\n" + "VADD.f32 q6, q6, q0\n" + "VADD.f32 q7, q7, q1\n" ASM_PREFETCH("[%[inptr], #416]") - "VMLA.f32 q6, q2, %q[av]\n" + "VMAX.f32 q4, q4, %q[minv]\n" + "VMAX.f32 q5, q5, %q[minv]\n" + "VMAX.f32 q6, q6, %q[minv]\n" ASM_PREFETCH("[%[inptr], #480]") - "VMLA.f32 q7, q3, %q[av]\n" + "VMAX.f32 q7, q7, %q[minv]\n" + "VMIN.f32 q4, q4, %q[maxv]\n" + "VMIN.f32 q5, q5, %q[maxv]\n" + "VST1.32 {d8-d11}, [%[outptr0]]!\n" + "VMIN.f32 q6, q6, %q[maxv]\n" + "VMIN.f32 q7, q7, %q[maxv]\n" "VST1.32 {d12-d15}, [%[outptr1]]!\n" // Rows 2-3 - "VLD1.32 {d8-d11}, [%[outptr2]]\n" - "VMUL.f32 q4, q4, %q[bv]\n" - "VLD1.32 {d12-d15}, [%[outptr3]]\n" - "VMUL.f32 q5, q5, %q[bv]\n" - "VLD1.32 {d0-d3}, [%[inptr]]!\n" - "VMUL.f32 q6, q6, %q[bv]\n" - "VLD1.32 {d4-d7}, [%[inptr]]!\n" - "VMUL.f32 q7, q7, %q[bv]\n" + "VLD1.32 {d8-d11}, [%[inptr]]!\n" + "VLD1.32 {d12-d15}, [%[inptr]]!\n" - "VMLA.f32 q4, q0, %q[av]\n" + "VADD.f32 q4, q4, q0\n" ASM_PREFETCH("[%[outptr0], #96]") - "VMLA.f32 q5, q1, %q[av]\n" - "VST1.32 {d8-d11}, [%[outptr2]]!\n" + "VADD.f32 q5, q5, q1\n" + "VADD.f32 q6, q6, q0\n" + "VADD.f32 q7, q7, q1\n" ASM_PREFETCH("[%[outptr1], #96]") - "VMLA.f32 q6, q2, %q[av]\n" - ASM_PREFETCH("[%[outptr2], #96]") - "VMLA.f32 q7, q3, %q[av]\n" + "VMAX.f32 q4, q4, %q[minv]\n" + "VMAX.f32 q5, q5, %q[minv]\n" + "VMAX.f32 q6, q6, %q[minv]\n" + ASM_PREFETCH("[%[outptr2], #128]") + "VMAX.f32 q7, q7, %q[minv]\n" + "VMIN.f32 q4, q4, %q[maxv]\n" + "VMIN.f32 q5, q5, %q[maxv]\n" + "VST1.32 {d8-d11}, [%[outptr2]]!\n" + "VMIN.f32 q6, q6, %q[maxv]\n" + "VMIN.f32 q7, q7, %q[maxv]\n" "VST1.32 {d12-d15}, [%[outptr3]]!\n" // Rows 4-5 - "VLD1.32 {d8-d11}, [%[outptr4]]\n" - "VMUL.f32 q4, q4, %q[bv]\n" - "VLD1.32 {d12-d15}, [%[outptr5]]\n" - "VMUL.f32 q5, q5, %q[bv]\n" - "VLD1.32 {d0-d3}, [%[inptr]]!\n" - "VMUL.f32 q6, q6, %q[bv]\n" - "VLD1.32 {d4-d7}, [%[inptr]]!\n" - "VMUL.f32 q7, q7, %q[bv]\n" + "VLD1.32 {d8-d11}, [%[inptr]]!\n" + "VLD1.32 {d12-d15}, 
[%[inptr]]!\n" - "VMLA.f32 q4, q0, %q[av]\n" + "VADD.f32 q4, q4, q0\n" ASM_PREFETCH("[%[outptr3], #96]") - "VMLA.f32 q5, q1, %q[av]\n" - "VST1.32 {d8-d11}, [%[outptr4]]!\n" - ASM_PREFETCH("[%[outptr4], #96]") - "VMLA.f32 q6, q2, %q[av]\n" + "VADD.f32 q5, q5, q1\n" + "VADD.f32 q6, q6, q0\n" + "VADD.f32 q7, q7, q1\n" + ASM_PREFETCH("[%[outptr4], #128]") + "VMAX.f32 q4, q4, %q[minv]\n" + "VMAX.f32 q5, q5, %q[minv]\n" + "VMAX.f32 q6, q6, %q[minv]\n" ASM_PREFETCH("[%[outptr5], #128]") - "VMLA.f32 q7, q3, %q[av]\n" + "VMAX.f32 q7, q7, %q[minv]\n" + "VMIN.f32 q4, q4, %q[maxv]\n" + "VMIN.f32 q5, q5, %q[maxv]\n" + "VST1.32 {d8-d11}, [%[outptr4]]!\n" + "VMIN.f32 q6, q6, %q[maxv]\n" + "VMIN.f32 q7, q7, %q[maxv]\n" "VST1.32 {d12-d15}, [%[outptr5]]!\n" : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [inptr] "+r" (inptr) - : [av] "w" (av), [bv] "w" (bv) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7" + : [minv] "w" (minv), [maxv] "w" (maxv), [biasptr] "r" (biasptr) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "memory" ); } } |