-rw-r--r--  arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h |   6
-rw-r--r--  src/core/CL/cl_kernels/gemm.cl                            | 260
-rw-r--r--  src/runtime/CL/functions/CLFullyConnectedLayer.cpp        |  23
3 files changed, 219 insertions(+), 70 deletions(-)
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index 584266b824..67c0467f3a 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -27,11 +27,11 @@
#include "arm_compute/runtime/CL/ICLSimpleFunction.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
#include "arm_compute/core/CL/kernels/CLTransposeKernel.h"
#include "arm_compute/runtime/CL/CLMemoryGroup.h"
#include "arm_compute/runtime/CL/CLTensor.h"
+#include "arm_compute/runtime/CL/functions/CLGEMM.h"
#include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
@@ -113,12 +113,12 @@ public:
private:
void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
- void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed = true);
+ void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
CLMemoryGroup _memory_group;
CLIm2ColKernel _im2col_kernel;
CLFullyConnectedLayerReshapeWeights _reshape_weights_kernel;
- CLGEMMMatrixMultiplyKernel _mm_kernel;
+ CLGEMM _mm_gemm;
CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
CLGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel;
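
Taken together with the .cpp changes further down, this header change replaces the bare CLGEMMMatrixMultiplyKernel member with a full CLGEMM function. A minimal sketch of the new float path, assuming the hypothetical wrapper name configure_mm_float (the real code is CLFullyConnectedLayer::configure_mm shown below):

#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Types.h" // GEMMInfo
#include "arm_compute/runtime/CL/functions/CLGEMM.h"

using namespace arm_compute;

// Member that replaces CLGEMMMatrixMultiplyKernel _mm_kernel.
CLGEMM _mm_gemm{};

// Hypothetical free-standing illustration: alpha = 1, beta = 0, no bias tensor,
// and the weights are reshaped only for the first run.
void configure_mm_float(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
{
    _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f,
                       GEMMInfo(false, false, true /* Reshape weights only for the first run */));
}
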
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 4b1672ce7b..381130ea7f 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -342,9 +342,6 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
__global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
__global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
- // Compute end row address for matrix B
- __global float *src_end_addr_b = src_addr_b + COLS_B;
-
src_addr_a += offset_row_a;
src_addr_b += offset_row_b;
@@ -366,12 +363,18 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
float c32 = 0.0f;
float c33 = 0.0f;
- for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += (16 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (16 * MULT_TRANSPOSE1XW_WIDTH))
+#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
+
+ int i = 0;
+ for(; i <= (int)(COLS_MTX_B - 4); i += 4)
{
// Load values from matrix A (interleaved) and matrix B (transposed)
float4 a0 = vload4(0, src_addr_a);
float4 b0 = vload4(0, src_addr_b);
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
+
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
c02 = fma(a0.s0, b0.s2, c02);
@@ -393,8 +396,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = fma(a0.s3, b0.s3, c33);
// Load values from matrix A (interleaved) and matrix B (transposed)
- a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
- b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
+ a0 = vload4(0, src_addr_a);
+ b0 = vload4(0, src_addr_b);
+
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
@@ -417,8 +423,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = fma(a0.s3, b0.s3, c33);
// Load values from matrix A (interleaved) and matrix B (transposed)
- a0 = vload4(0, src_addr_a + 8 * MULT_INTERLEAVE4X4_HEIGHT);
- b0 = vload4(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
+ a0 = vload4(0, src_addr_a);
+ b0 = vload4(0, src_addr_b);
+
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
@@ -441,8 +450,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = fma(a0.s3, b0.s3, c33);
// Load values from matrix A (interleaved) and matrix B (transposed)
- a0 = vload4(0, src_addr_a + 12 * MULT_INTERLEAVE4X4_HEIGHT);
- b0 = vload4(0, src_addr_b + 12 * MULT_TRANSPOSE1XW_WIDTH);
+ a0 = vload4(0, src_addr_a);
+ b0 = vload4(0, src_addr_b);
+
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
@@ -465,12 +477,15 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = fma(a0.s3, b0.s3, c33);
}
- for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * MULT_TRANSPOSE1XW_WIDTH))
+ for(; i < (int)(COLS_MTX_B); ++i)
{
// Load values from matrix A (interleaved) and matrix B (transposed)
float4 a0 = vload4(0, src_addr_a);
float4 b0 = vload4(0, src_addr_b);
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
+
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
c02 = fma(a0.s0, b0.s2, c02);
@@ -1130,9 +1145,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)
- // Address boundary for matrix A
- int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
-
// Initialize accumulators
float acc00 = 0.0f;
float acc01 = 0.0f;
@@ -1161,72 +1173,162 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
// A and B src indices get incremented at the same time.
- for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
+ int i = 0;
+ for(; i <= ((int)COLS_A - 4); i += 4)
{
- // Load values from matrix A
- float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+ // Load values from matrix A and matrix B
+ float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- float2 a1 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+ float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- float2 a2 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+ float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- float2 a3 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+ float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- // Load values from matrix B
- float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
- float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
+ float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
// Multiply and accumulate
acc00 = fma(a0.s0, b0.s0, acc00);
- acc00 = fma(a0.s1, b1.s0, acc00);
acc01 = fma(a0.s0, b0.s1, acc01);
- acc01 = fma(a0.s1, b1.s1, acc01);
acc02 = fma(a0.s0, b0.s2, acc02);
- acc02 = fma(a0.s1, b1.s2, acc02);
- acc03 = fma(a0.s1, b1.s3, acc03);
acc03 = fma(a0.s0, b0.s3, acc03);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
acc10 = fma(a1.s0, b0.s0, acc10);
acc11 = fma(a1.s0, b0.s1, acc11);
acc12 = fma(a1.s0, b0.s2, acc12);
acc13 = fma(a1.s0, b0.s3, acc13);
- acc10 = fma(a1.s1, b1.s0, acc10);
- acc11 = fma(a1.s1, b1.s1, acc11);
- acc12 = fma(a1.s1, b1.s2, acc12);
- acc13 = fma(a1.s1, b1.s3, acc13);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
acc20 = fma(a2.s0, b0.s0, acc20);
acc21 = fma(a2.s0, b0.s1, acc21);
acc22 = fma(a2.s0, b0.s2, acc22);
acc23 = fma(a2.s0, b0.s3, acc23);
- acc20 = fma(a2.s1, b1.s0, acc20);
- acc21 = fma(a2.s1, b1.s1, acc21);
- acc22 = fma(a2.s1, b1.s2, acc22);
- acc23 = fma(a2.s1, b1.s3, acc23);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
acc30 = fma(a3.s0, b0.s0, acc30);
acc31 = fma(a3.s0, b0.s1, acc31);
acc32 = fma(a3.s0, b0.s2, acc32);
acc33 = fma(a3.s0, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ // Load values from matrix A and matrix B
+ b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+
+ // Multiply and accumulate
+ acc00 = fma(a0.s1, b0.s0, acc00);
+ acc01 = fma(a0.s1, b0.s1, acc01);
+ acc02 = fma(a0.s1, b0.s2, acc02);
+ acc03 = fma(a0.s1, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+ acc10 = fma(a1.s1, b0.s0, acc10);
+ acc11 = fma(a1.s1, b0.s1, acc11);
+ acc12 = fma(a1.s1, b0.s2, acc12);
+ acc13 = fma(a1.s1, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+ acc20 = fma(a2.s1, b0.s0, acc20);
+ acc21 = fma(a2.s1, b0.s1, acc21);
+ acc22 = fma(a2.s1, b0.s2, acc22);
+ acc23 = fma(a2.s1, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc30 = fma(a3.s1, b1.s0, acc30);
- acc31 = fma(a3.s1, b1.s1, acc31);
- acc32 = fma(a3.s1, b1.s2, acc32);
- acc33 = fma(a3.s1, b1.s3, acc33);
+ acc30 = fma(a3.s1, b0.s0, acc30);
+ acc31 = fma(a3.s1, b0.s1, acc31);
+ acc32 = fma(a3.s1, b0.s2, acc32);
+ acc33 = fma(a3.s1, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ // Load values from matrix A and matrix B
+ b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+
+ // Multiply and accumulate
+ acc00 = fma(a0.s2, b0.s0, acc00);
+ acc01 = fma(a0.s2, b0.s1, acc01);
+ acc02 = fma(a0.s2, b0.s2, acc02);
+ acc03 = fma(a0.s2, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+ acc10 = fma(a1.s2, b0.s0, acc10);
+ acc11 = fma(a1.s2, b0.s1, acc11);
+ acc12 = fma(a1.s2, b0.s2, acc12);
+ acc13 = fma(a1.s2, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+ acc20 = fma(a2.s2, b0.s0, acc20);
+ acc21 = fma(a2.s2, b0.s1, acc21);
+ acc22 = fma(a2.s2, b0.s2, acc22);
+ acc23 = fma(a2.s2, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ acc30 = fma(a3.s2, b0.s0, acc30);
+ acc31 = fma(a3.s2, b0.s1, acc31);
+ acc32 = fma(a3.s2, b0.s2, acc32);
+ acc33 = fma(a3.s2, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ // Load values from matrix A and matrix B
+ b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+
+ // Multiply and accumulate
+ acc00 = fma(a0.s3, b0.s0, acc00);
+ acc01 = fma(a0.s3, b0.s1, acc01);
+ acc02 = fma(a0.s3, b0.s2, acc02);
+ acc03 = fma(a0.s3, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+ acc10 = fma(a1.s3, b0.s0, acc10);
+ acc11 = fma(a1.s3, b0.s1, acc11);
+ acc12 = fma(a1.s3, b0.s2, acc12);
+ acc13 = fma(a1.s3, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+ acc20 = fma(a2.s3, b0.s0, acc20);
+ acc21 = fma(a2.s3, b0.s1, acc21);
+ acc22 = fma(a2.s3, b0.s2, acc22);
+ acc23 = fma(a2.s3, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ acc30 = fma(a3.s3, b0.s0, acc30);
+ acc31 = fma(a3.s3, b0.s1, acc31);
+ acc32 = fma(a3.s3, b0.s2, acc32);
+ acc33 = fma(a3.s3, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += 4 * sizeof(float);
}
- for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
+ for(; i < (int)COLS_A; ++i)
{
// Load values from matrix A
- float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1238,6 +1340,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
// Load values from matrix B
float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
// Multiply and accumulate
acc00 = fma(a0, b0.s0, acc00);
@@ -1262,6 +1365,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
acc32 = fma(a3, b0.s2, acc32);
acc33 = fma(a3, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += sizeof(float);
}
// Compute destination address
@@ -1375,9 +1480,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)
- // Address boundary for the matrix A
- int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
-
// Initialize accumulators
float acc00 = 0.0f;
float acc01 = 0.0f;
@@ -1396,67 +1498,114 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
// A and B src indices get incremented at the same time.
- for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y))
+ int i = 0;
+ for(; i <= ((int)COLS_A - 8); i += 8)
{
// Load values from matrix A
- float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+ float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
// Load values from matrix B
- float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
- float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
- float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y));
- float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y));
+ float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
// Multiply and accumulate
acc00 = fma(a0.s0, b0.s0, acc00);
acc00 = fma(a0.s1, b1.s0, acc00);
acc00 = fma(a0.s2, b2.s0, acc00);
acc00 = fma(a0.s3, b3.s0, acc00);
+ acc00 = fma(a0.s4, b4.s0, acc00);
+ acc00 = fma(a0.s5, b5.s0, acc00);
+ acc00 = fma(a0.s6, b6.s0, acc00);
+ acc00 = fma(a0.s7, b7.s0, acc00);
acc01 = fma(a0.s0, b0.s1, acc01);
acc01 = fma(a0.s1, b1.s1, acc01);
acc01 = fma(a0.s2, b2.s1, acc01);
acc01 = fma(a0.s3, b3.s1, acc01);
+ acc01 = fma(a0.s4, b4.s1, acc01);
+ acc01 = fma(a0.s5, b5.s1, acc01);
+ acc01 = fma(a0.s6, b6.s1, acc01);
+ acc01 = fma(a0.s7, b7.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
acc10 = fma(a0.s0, b0.s0, acc10);
acc10 = fma(a0.s1, b1.s0, acc10);
acc10 = fma(a0.s2, b2.s0, acc10);
acc10 = fma(a0.s3, b3.s0, acc10);
+ acc10 = fma(a0.s4, b4.s0, acc10);
+ acc10 = fma(a0.s5, b5.s0, acc10);
+ acc10 = fma(a0.s6, b6.s0, acc10);
+ acc10 = fma(a0.s7, b7.s0, acc10);
acc11 = fma(a0.s0, b0.s1, acc11);
acc11 = fma(a0.s1, b1.s1, acc11);
acc11 = fma(a0.s2, b2.s1, acc11);
acc11 = fma(a0.s3, b3.s1, acc11);
+ acc11 = fma(a0.s4, b4.s1, acc11);
+ acc11 = fma(a0.s5, b5.s1, acc11);
+ acc11 = fma(a0.s6, b6.s1, acc11);
+ acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
acc20 = fma(a0.s0, b0.s0, acc20);
acc20 = fma(a0.s1, b1.s0, acc20);
acc20 = fma(a0.s2, b2.s0, acc20);
acc20 = fma(a0.s3, b3.s0, acc20);
+ acc20 = fma(a0.s4, b4.s0, acc20);
+ acc20 = fma(a0.s5, b5.s0, acc20);
+ acc20 = fma(a0.s6, b6.s0, acc20);
+ acc20 = fma(a0.s7, b7.s0, acc20);
acc21 = fma(a0.s0, b0.s1, acc21);
acc21 = fma(a0.s1, b1.s1, acc21);
acc21 = fma(a0.s2, b2.s1, acc21);
acc21 = fma(a0.s3, b3.s1, acc21);
+ acc21 = fma(a0.s4, b4.s1, acc21);
+ acc21 = fma(a0.s5, b5.s1, acc21);
+ acc21 = fma(a0.s6, b6.s1, acc21);
+ acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
acc30 = fma(a0.s0, b0.s0, acc30);
acc30 = fma(a0.s1, b1.s0, acc30);
acc30 = fma(a0.s2, b2.s0, acc30);
acc30 = fma(a0.s3, b3.s0, acc30);
+ acc30 = fma(a0.s4, b4.s0, acc30);
+ acc30 = fma(a0.s5, b5.s0, acc30);
+ acc30 = fma(a0.s6, b6.s0, acc30);
+ acc30 = fma(a0.s7, b7.s0, acc30);
acc31 = fma(a0.s0, b0.s1, acc31);
acc31 = fma(a0.s1, b1.s1, acc31);
acc31 = fma(a0.s2, b2.s1, acc31);
acc31 = fma(a0.s3, b3.s1, acc31);
+ acc31 = fma(a0.s4, b4.s1, acc31);
+ acc31 = fma(a0.s5, b5.s1, acc31);
+ acc31 = fma(a0.s6, b6.s1, acc31);
+ acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += sizeof(float) * 8;
}
// float size increment
- for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(4, src1_stride_y))
+ for(; i < (int)COLS_A; ++i)
{
// Load values from matrix A
float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -1471,6 +1620,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
// Load values from matrix B
float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
// Multiply and accumulate
acc00 = fma(a0, b0.s0, acc00);
@@ -1487,6 +1637,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc30 = fma(a3, b0.s0, acc30);
acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += sizeof(float);
}
// Compute destination address
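
All of the gemm.cl hunks above follow one pattern: drop the end-address bound (src_end_addr_b, end_row_vec_a), drive the loop with an integer counter checked against COLS_MTX_B or COLS_A, advance the source addresses immediately after each load instead of using growing offsets, and keep a 4x (or 8x) unrolled main loop plus a one-element leftover loop. A minimal scalar sketch of that shape in plain C++, with hypothetical names (dot_unrolled, cols); the real kernels work on float4/float8 vectors with the MULT_* strides:

#include <cmath>

// Illustrative only: counter-based main loop unrolled by 4, pointers bumped
// right after each load, followed by a scalar leftover loop on the same counter.
float dot_unrolled(const float *a, const float *b, int cols)
{
    float acc = 0.0f;

    int i = 0;
    for(; i <= cols - 4; i += 4)
    {
        acc = std::fma(*a++, *b++, acc);
        acc = std::fma(*a++, *b++, acc);
        acc = std::fma(*a++, *b++, acc);
        acc = std::fma(*a++, *b++, acc);
    }
    for(; i < cols; ++i)
    {
        acc = std::fma(*a++, *b++, acc);
    }
    return acc;
}
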
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 5dd1f00fa6..9b3bf48bca 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -37,10 +37,8 @@ using namespace arm_compute::misc::shape_calculator;
namespace
{
-Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output, bool is_interleaved_transposed)
+Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output)
{
- const GPUTarget gpu_target = CLScheduler::get().target();
-
if(is_data_type_quantized_asymmetric(input.data_type()))
{
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -55,7 +53,7 @@ Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const I
}
else
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&input, &weights, &output, 1.f, is_interleaved_transposed, GEMMReshapeInfo(), gpu_target));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(&input, &weights, nullptr, &output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */)));
}
return Status{};
@@ -75,12 +73,12 @@ Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c
}
CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _im2col_output(),
- _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr)
+ : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(),
+ _im2col_output(), _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr)
{
}
-void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed)
+void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
{
if(_is_quantized)
{
@@ -102,8 +100,7 @@ void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor
else
{
// Configure matrix multiply kernel
- _mm_kernel.set_target(CLScheduler::get().target());
- _mm_kernel.configure(input, weights, output, 1.f, is_interleaved_transposed);
+ _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */));
}
}
@@ -122,7 +119,7 @@ void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLT
_im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
// Configure matrix multiply kernel
- configure_mm(&_im2col_output, weights, output, false);
+ configure_mm(&_im2col_output, weights, output);
// Allocate the output tensor for im2col once all the configure methods have been called
_im2col_output.allocator()->allocate();
@@ -133,7 +130,7 @@ void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTen
ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));
// Configure matrix multiply kernel
- configure_mm(input, weights, output, false);
+ configure_mm(input, weights, output);
}
void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped)
@@ -301,7 +298,7 @@ Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1));
}
// Validate matrix multiply kernel
- ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output));
// Validate output stage for asymmetric quantized types
if(is_quantized)
@@ -341,7 +338,7 @@ void CLFullyConnectedLayer::run()
}
else
{
- CLScheduler::get().enqueue(_mm_kernel, !_accumulate_biases);
+ _mm_gemm.run();
}
// Accumulate biases if provided
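
At run time the float branch now launches the GEMM function instead of enqueueing the single kernel, while the quantized branch keeps the GEMMLowp path. A hedged sketch of that dispatch (run_mm is a hypothetical name for the relevant part of CLFullyConnectedLayer::run(); the quantized branch is not shown in the hunks above):

// _is_quantized, _mm_gemmlowp and _mm_gemm are the members referenced in the diff.
void run_mm()
{
    if(_is_quantized)
    {
        _mm_gemmlowp.run(); // assumed quantized path (GEMMLowp core)
    }
    else
    {
        _mm_gemm.run();     // float path, replaces CLScheduler::get().enqueue(_mm_kernel, ...)
    }
}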