author    Gian Marco <gianmarco.iodice@arm.com>    2018-02-15 12:35:44 +0000
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:48:33 +0000
commit    ae2af74ae4368004221a41e6891e0173453996ac (patch)
tree      a9d16fd683ee45e1caf071c0175c9d61cb99fdc3
parent    d56e770e7c394d13706a21ee350e7dafe4278987 (diff)
download  ComputeLibrary-ae2af74ae4368004221a41e6891e0173453996ac.tar.gz
COMPMID-935 - Implementing Convolution with Winograd on OpenCL (Part 1)
This patch enables GEMM to execute multiple batches in parallel.

https://confluence.arm.com/display/MLENG/Winograd%3A+batched+GEMM

Change-Id: I66222db041dd35e82af11fbb262fd1ebd3ca4b2f
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/120866
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
-rw-r--r--  examples/cl_sgemm.cpp                                 2
-rw-r--r--  src/core/CL/cl_kernels/gemm.cl                      276
-rw-r--r--  src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp    21
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp   30
-rw-r--r--  src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp     15
5 files changed, 255 insertions(+), 89 deletions(-)
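
The batching pattern is the same in every kernel touched below: the third work-item dimension (get_global_id(2)) selects the batch, a z * stride_z byte offset is added to the matrix A, matrix B and destination addresses, and the three Z strides are passed as extra cl_uint kernel arguments. A minimal, stand-alone OpenCL sketch of that addressing scheme follows; the kernel and parameter names are illustrative only and are not part of the patch.

__kernel void batched_row_copy(__global const uchar *src_ptr,
                               uint src_stride_y,
                               uint src_stride_z,
                               __global uchar *dst_ptr,
                               uint dst_stride_y,
                               uint dst_stride_z)
{
    const uint x = get_global_id(0); // byte within the row
    const uint y = get_global_id(1); // row within one matrix
    const uint z = get_global_id(2); // batch index, as added by this patch's kernels

    // Same idea as the gemm.cl changes: the Z stride moves the pointer to the
    // selected batch; a stride of 0 makes the term a no-op, which is how matrix B
    // can be kept fixed across batches for the convolution case.
    dst_ptr[x + y * dst_stride_y + z * dst_stride_z] =
        src_ptr[x + y * src_stride_y + z * src_stride_z];
}

Enqueued with a 3D NDRange of (row_bytes, rows, batches), each batch is handled by an independent slice of work-items, which is what lets the batched GEMM run in parallel along Z.
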
diff --git a/examples/cl_sgemm.cpp b/examples/cl_sgemm.cpp
index 966661b9b4..fa57885450 100644
--- a/examples/cl_sgemm.cpp
+++ b/examples/cl_sgemm.cpp
@@ -198,4 +198,4 @@ private:
int main(int argc, char **argv)
{
return utils::run_example<CLSGEMMExample>(argc, argv);
-} \ No newline at end of file
+}
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 58a550f77d..cba5eea437 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -49,27 +49,35 @@
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_transpose1xW(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
uint x = get_global_id(0);
uint y = get_global_id(1);
+ uint z = get_global_id(2);
// Compute address for Matrix B - source
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
// Compute address for Matrix B transposed - destination. X and Y are swapped
uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
(x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);
+ // Add offset for batched GEMM
+ dst_addr_in_bytes += z * dst_stride_z;
+
VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);
@@ -90,37 +98,47 @@ __kernel void gemm_transpose1xW(IMAGE_DECLARATION(src),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_interleave4x4(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
uint x = get_global_id(0);
uint y = get_global_id(1);
+ uint z = get_global_id(2);
- // Compute address for Matrix B - source
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
+ // Compute address for source tensor
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
// Compute address for Matrix B transposed - destination. X and Y are swapped
uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
(y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);
+ // Add offset for batched GEMM
+ dst_addr_in_bytes += z * dst_stride_z;
+
+ __global uchar *input_ptr = src.ptr;
+
// Load values from Matrix A
VEC_DATA_TYPE(DATA_TYPE, 4)
- a0 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 0)));
+ a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
- a1 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 1)));
+ a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
- a2 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 2)));
+ a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
- a3 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 3)));
+ a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
@@ -166,10 +184,14 @@ __kernel void gemm_interleave4x4(IMAGE_DECLARATION(src),
*/
__kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+ int z = get_global_id(2);
// Offset
const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
@@ -177,8 +199,8 @@ __kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0)
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global float *src_addr_a = (__global float *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global float *src_addr_b = (__global float *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
+ __global float *src_addr_a = (__global float *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
+ __global float *src_addr_b = (__global float *)(src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes);
// Compute end row address for matrix B
__global float *src_end_addr_b = src_addr_b + COLS_B;
@@ -236,11 +258,17 @@ __kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0)
c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
// Store 4x4 block
- vstore4(c00, 0, (__global float *)(offset(&dst, 0, 0)));
- vstore4(c10, 0, (__global float *)(offset(&dst, 0, 1)));
- vstore4(c20, 0, (__global float *)(offset(&dst, 0, 2)));
- vstore4(c30, 0, (__global float *)(offset(&dst, 0, 3)));
+ vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+ vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+ vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
}
/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -271,10 +299,14 @@ __kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0)
*/
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+ int z = get_global_id(2);
// Offset
const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
@@ -282,8 +314,8 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global float *src_addr_a = (__global float *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global float *src_addr_b = (__global float *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
+ __global float *src_addr_a = (__global float *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
+ __global float *src_addr_b = (__global float *)(src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes);
// Compute end row address for matrix B
__global float *src_end_addr_b = src_addr_b + COLS_B;
@@ -458,11 +490,17 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = c33 * ALPHA;
#endif // defined(ALPHA)
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
// Store 4x4 block
- vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(offset(&dst, 0, 0)));
- vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(offset(&dst, 0, 1)));
- vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(offset(&dst, 0, 2)));
- vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(offset(&dst, 0, 3)));
+ vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+ vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+ vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
}
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
@@ -494,10 +532,14 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
*/
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+ int z = get_global_id(2);
// Offset
const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
@@ -505,8 +547,8 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global half *src_addr_a = (__global half *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global half *src_addr_b = (__global half *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
+ __global half *src_addr_a = (__global half *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
+ __global half *src_addr_b = (__global half *)(src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes);
// Compute end row address for matrix B
__global half *src_end_addr_b = src_addr_b + COLS_B;
@@ -564,11 +606,17 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
// Store 4x8 block
- vstore8(c00, 0, (__global half *)(offset(&dst, 0, 0)));
- vstore8(c10, 0, (__global half *)(offset(&dst, 0, 1)));
- vstore8(c20, 0, (__global half *)(offset(&dst, 0, 2)));
- vstore8(c30, 0, (__global half *)(offset(&dst, 0, 3)));
+ vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
+ vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
+ vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
+ vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
@@ -603,10 +651,14 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
*/
__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+ int z = get_global_id(2);
// Offset
const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
@@ -614,8 +666,8 @@ __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global char *src_addr_a = src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes;
- __global char *src_addr_b = src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes;
+ __global char *src_addr_a = src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
+ __global char *src_addr_b = src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes;
// Compute end row address for matrix B
__global char *src_end_addr_b = src_addr_b + COLS_B;
@@ -667,11 +719,17 @@ __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
// Store 16x4 block
- vstore16(c00_qs8, 0, (__global char *)(offset(&dst, 0, 0)));
- vstore16(c10_qs8, 0, (__global char *)(offset(&dst, 0, 1)));
- vstore16(c20_qs8, 0, (__global char *)(offset(&dst, 0, 2)));
- vstore16(c30_qs8, 0, (__global char *)(offset(&dst, 0, 3)));
+ vstore16(c00_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
+ vstore16(c10_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
+ vstore16(c20_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
+ vstore16(c30_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
}
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
@@ -704,10 +762,14 @@ __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
*/
__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+ int z = get_global_id(2);
// Offset
const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
@@ -715,8 +777,8 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global short *src_addr_a = (__global short *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global short *src_addr_b = (__global short *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
+ __global short *src_addr_a = (__global short *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
+ __global short *src_addr_b = (__global short *)(src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes);
// Compute end row address for matrix B
__global short *src_end_addr_b = src_addr_b + COLS_B;
@@ -759,11 +821,17 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
// Store 8x4 block
- vstore8(c00_qs16, 0, (__global short *)(offset(&dst, 0, 0)));
- vstore8(c10_qs16, 0, (__global short *)(offset(&dst, 0, 1)));
- vstore8(c20_qs16, 0, (__global short *)(offset(&dst, 0, 2)));
- vstore8(c30_qs16, 0, (__global short *)(offset(&dst, 0, 3)));
+ vstore8(c00_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
+ vstore8(c10_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
+ vstore8(c20_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
+ vstore8(c30_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
}
#endif // defined(FIXED_POINT_POSITION)
#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
@@ -799,7 +867,10 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
*/
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -812,6 +883,10 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(DATA_TYPE);
+ // Add offset for batched GEMM
+ src_addr.s0 += get_global_id(2) * src0_stride_z;
+ src_addr.s1 += get_global_id(2) * src1_stride_z;
+
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
VECTOR_TYPE acc0 = 0.0f;
@@ -895,32 +970,38 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += get_global_id(2) * dst_stride_z;
+
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
- (acc0, 0, (__global DATA_TYPE *)(offset(&dst, 0, 0)));
+ (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
acc1 = acc1 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
- (acc1, 0, (__global DATA_TYPE *)(offset(&dst, 0, 1)));
+ (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
acc2 = acc2 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
- (acc2, 0, (__global DATA_TYPE *)(offset(&dst, 0, 2)));
+ (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
acc3 = acc3 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
- (acc3, 0, (__global DATA_TYPE *)(offset(&dst, 0, 3)));
+ (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
#endif // defined(DATA_TYPE)
@@ -954,7 +1035,10 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
*/
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -967,6 +1051,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
// Update address for matrix B
src_addr.s1 += idx * sizeof(float);
+ // Add offset for batched GEMM
+ src_addr.s0 += get_global_id(2) * src0_stride_z;
+
+ // For convolution layer we do not want to slide the matrix B along Z
+ src_addr.s1 += get_global_id(2) * src1_stride_z;
+
// Address boundary for matrix A
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
@@ -1112,8 +1202,14 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
acc03 = acc03 * ALPHA;
#endif // defined(ALPHA)
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += get_global_id(2) * dst_stride_z;
+
float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
- vstore4(acc0, 0, (__global float *)(offset(&dst, 0, 0)));
+ vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
@@ -1123,7 +1219,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
acc13 = acc13 * ALPHA;
#endif // defined(ALPHA)
float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
- vstore4(acc1, 0, (__global float *)(offset(&dst, 0, 1)));
+ vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
@@ -1133,7 +1229,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
acc23 = acc23 * ALPHA;
#endif // defined(ALPHA)
float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
- vstore4(acc2, 0, (__global float *)(offset(&dst, 0, 2)));
+ vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
@@ -1143,7 +1239,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
acc33 = acc33 * ALPHA;
#endif // defined(ALPHA)
float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
- vstore4(acc3, 0, (__global float *)(offset(&dst, 0, 3)));
+ vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
@@ -1177,7 +1273,10 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
*/
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
// Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1191,6 +1290,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(float);
+ // Add offset for batched GEMM
+ src_addr.s0 += get_global_id(2) * src0_stride_z;
+
+ // For convolution layer we do not want to slide the matrix B along Z
+ src_addr.s1 += get_global_id(2) * src1_stride_z;
+
// Address boundary for the matrix A
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
@@ -1308,20 +1413,26 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += get_global_id(2) * dst_stride_z;
+
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
acc00 = acc00 * ALPHA;
acc01 = acc01 * ALPHA;
#endif // defined(ALPHA)
float2 acc0 = ((float2)(acc00, acc01));
- vstore2(acc0, 0, (__global float *)(offset(&dst, 0, 0)));
+ vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
acc10 = acc10 * ALPHA;
acc11 = acc11 * ALPHA;
#endif // defined(ALPHA)
float2 acc1 = ((float2)(acc10, acc11));
- vstore2(acc1, 0, (__global float *)(offset(&dst, 0, 1)));
+ vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
@@ -1329,7 +1440,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc21 = acc21 * ALPHA;
#endif // defined(ALPHA)
float2 acc2 = ((float2)(acc20, acc21));
- vstore2(acc2, 0, (__global float *)(offset(&dst, 0, 2)));
+ vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
@@ -1337,7 +1448,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc31 = acc31 * ALPHA;
#endif // defined(ALPHA)
float2 acc3 = (float2)(acc30, acc31);
- vstore2(acc3, 0, (__global float *)(offset(&dst, 0, 3)));
+ vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
@@ -1371,7 +1482,10 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
*/
__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1384,6 +1498,10 @@ __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(char);
+ // Add offset for batched GEMM
+ src_addr.s0 += get_global_id(2) * src0_stride_z;
+ src_addr.s1 += get_global_id(2) * src1_stride_z;
+
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));
short8 acc00 = 0;
@@ -1475,33 +1593,39 @@ __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+ // Add offset for batched GEMM
+ dst_addr += get_global_id(2) * dst_stride_z;
+
// Multiply by the weight of matrix product and store the result
char16 acc_qs8;
acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
#if defined(ALPHA)
acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
- vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 0)));
+ vstore16(acc_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
#if defined(ALPHA)
acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
- vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 1)));
+ vstore16(acc_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
#if defined(ALPHA)
acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
- vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 2)));
+ vstore16(acc_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
#if defined(ALPHA)
acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
- vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 3)));
+ vstore16(acc_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
@@ -1534,7 +1658,10 @@ __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
*/
__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z)
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1547,6 +1674,10 @@ __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(short);
+ // Add offset for batched GEMM
+ src_addr.s0 += get_global_id(2) * src0_stride_z;
+ src_addr.s1 += get_global_id(2) * src1_stride_z;
+
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));
int8 acc0 = 0;
@@ -1622,33 +1753,36 @@ __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
// Multiply by the weight of matrix product and store the result
short8 acc_qs16;
acc_qs16 = convert_short8_sat(acc0);
#if defined(ALPHA)
acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
- vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 0)));
+ vstore8(acc_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
acc_qs16 = convert_short8_sat(acc1);
#if defined(ALPHA)
acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
- vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 1)));
+ vstore8(acc_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
acc_qs16 = convert_short8_sat(acc2);
#if defined(ALPHA)
acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
- vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 2)));
+ vstore8(acc_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
acc_qs16 = convert_short8_sat(acc3);
#if defined(ALPHA)
acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
- vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 3)));
+ vstore8(acc_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
#endif // defined(FIXED_POINT_POSITION)
diff --git a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
index 241dd8549d..d12255ff24 100644
--- a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
+++ b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
@@ -80,8 +80,12 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
output_access.set_valid_region(win, input->valid_region());
}
+ // Collapse along the Z direction
+ // This collapse needs to be here in order to tune the Z dimension of LWS
+ Window collapsed = win.collapse(win, Window::DimZ);
+
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
+ return std::make_pair(err, collapsed);
}
} // namespace
@@ -136,6 +140,10 @@ void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *out
_config_id += support::cpp11::to_string(output->info()->dimension(0));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(2));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(3));
}
Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
@@ -160,15 +168,14 @@ void CLGEMMInterleave4x4Kernel::run(const Window &window, cl::CommandQueue &queu
*
* After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]
*/
- Window in_slice = window.first_slice_window_2D();
- Window out_slice = window.first_slice_window_2D();
+ Window slice = window.first_slice_window_3D();
do
{
unsigned int idx = 0;
- add_2D_tensor_argument(idx, _input, in_slice);
- add_2D_tensor_argument(idx, _output, out_slice);
- enqueue(queue, *this, in_slice, _lws_hint);
+ add_3D_tensor_argument(idx, _input, slice);
+ add_3D_tensor_argument(idx, _output, slice);
+ enqueue(queue, *this, slice, _lws_hint);
}
- while(window.slide_window_slice_2D(in_slice) && window.slide_window_slice_2D(out_slice));
+ while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 3143075a9d..6655d12d7e 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -158,8 +158,17 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
}
+ // Collapse along the Z direction
+ // This collapse needs to be here in order to tune the Z dimension of LWS
+ Window collapsed = win;
+ if(input1->num_dimensions() > 1)
+ {
+ const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(input1->num_dimensions() - 1), 2u);
+ collapsed = win.collapse(win, dimension_to_collapse);
+ }
+
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
+ return std::make_pair(err, collapsed);
}
} // namespace
@@ -286,6 +295,10 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(0));
_config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(2));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(3));
+ _config_id += "_";
_config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1)));
}
@@ -312,7 +325,13 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
- Window slice = window.first_slice_window_2D();
+ if(_input1->info()->num_dimensions() < 3)
+ {
+ // The stride_z for matrix B must be zero if we do not slice
+ ARM_COMPUTE_ERROR_ON(_input1->info()->strides_in_bytes()[3] != 0);
+ }
+
+ Window slice = window.first_slice_window_3D();
Window slice_matrix_b = slice;
slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
@@ -322,7 +341,7 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
{
Window slice_b = slice;
// Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
- // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
+ // This scenario can happen when the matrix multiplication is used to perform a convolution operation
if(_input1->info()->num_dimensions() < 3)
{
slice_b = slice_matrix_b;
@@ -332,7 +351,10 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
add_2D_tensor_argument(idx, _input0, slice);
add_2D_tensor_argument(idx, _input1, slice_b);
add_2D_tensor_argument(idx, _output, slice);
+ _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[3]));
+ _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[3]));
+ _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_output->info()->strides_in_bytes()[3]));
enqueue(queue, *this, slice, _lws_hint);
}
- while(window.slide_window_slice_2D(slice));
+ while(window.slide_window_slice_3D(slice));
}
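
The run() change above is the host side of the same scheme: after the usual 2D tensor arguments, three extra cl_uint strides are appended, and when matrix B is 2D its batch stride is required to be zero so the kernel's z term leaves B in place. A hedged sketch of that argument wiring in plain OpenCL host code; the helper and its names are hypothetical and not Compute Library API.

#include <CL/cl.h>

// Append the three batch strides introduced by this patch after the regular
// per-slice kernel arguments. Passing 0 for src1_stride_z keeps matrix B
// fixed across batches (the convolution case checked in run() above).
static void set_batched_gemm_strides(cl_kernel kernel,
                                     cl_uint   first_extra_arg,
                                     cl_uint   src0_stride_z,
                                     cl_uint   src1_stride_z,
                                     cl_uint   dst_stride_z)
{
    cl_uint idx = first_extra_arg;
    clSetKernelArg(kernel, idx++, sizeof(cl_uint), &src0_stride_z);
    clSetKernelArg(kernel, idx++, sizeof(cl_uint), &src1_stride_z);
    clSetKernelArg(kernel, idx++, sizeof(cl_uint), &dst_stride_z);
}
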
diff --git a/src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp b/src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp
index 24d218760e..5489fde818 100644
--- a/src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp
@@ -86,8 +86,11 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), input->tensor_shape()));
}
+ // Collapse along the Z direction
+ Window collapsed = win.collapse(win, Window::DimZ);
+
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
+ return std::make_pair(err, collapsed);
}
} // namespace
@@ -151,15 +154,15 @@ void CLGEMMTranspose1xWKernel::run(const Window &window, cl::CommandQueue &queue
out_window.set(Window::DimX, window.y());
out_window.set(Window::DimY, window.x());
- Window in_slice = window.first_slice_window_2D();
- Window out_slice = out_window.first_slice_window_2D();
+ Window in_slice = window.first_slice_window_3D();
+ Window out_slice = out_window.first_slice_window_3D();
do
{
unsigned int idx = 0;
- add_2D_tensor_argument(idx, _input, in_slice);
- add_2D_tensor_argument(idx, _output, out_slice);
+ add_3D_tensor_argument(idx, _input, in_slice);
+ add_3D_tensor_argument(idx, _output, out_slice);
enqueue(queue, *this, in_slice, _lws_hint);
}
- while(window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
+ while(window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
}