diff options
Diffstat (limited to 'src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl')
-rw-r--r-- | src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl | 32 |
1 file changed, 30 insertions(+), 2 deletions(-)
diff --git a/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl index 8919023d4c..09b8956b68 100644 --- a/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl +++ b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Arm Limited. + * Copyright (c) 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -117,9 +117,23 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul( uint rhs_y = block_id; // Compute LHS/RHS/DST matrix address +#ifdef REINTERPRET_INPUT_AS_3D + lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + (lhs_y + z * M) * lhs_stride_y; +#else // REINTERPRET_INPUT_AS_3D lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z; +#endif // REINTERPRET_INPUT_AS_3D + +#ifdef BATCHED_RHS rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z; +#else // BATCHED_RHS + rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y; +#endif // BATCHED_RHS + +#ifdef REINTERPRET_OUTPUT_AS_3D + dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + (dst_y + z * M) * dst_stride_y; +#else // REINTERPRET_OUTPUT_AS_3D dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z; +#endif // REINTERPRET_OUTPUT_AS_3D // Note: If RHS derives from the weights of convolution 2d layer, RHS will always be 2D and rhs_stride_z will always be equal to 0 for // not sliding the tensor @@ -367,11 +381,25 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture( // Starting RHS coordinates uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0; + +#ifdef BATCHED_RHS uint rhs_y = block_id + z * rhs_h; +#else // BATCHED_RHS + uint rhs_y = block_id; +#endif // BATCHED_RHS // Compute LHS/RHS/DST matrix address +#ifdef REINTERPRET_INPUT_AS_3D + 
lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + (lhs_y + z * M) * lhs_stride_y; +#else // REINTERPRET_INPUT_AS_3D lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z; +#endif // REINTERPRET_INPUT_AS_3D + +#ifdef REINTERPRET_OUTPUT_AS_3D + dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + (dst_y + z * M) * dst_stride_y; +#else // REINTERPRET_OUTPUT_AS_3D dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z; +#endif // REINTERPRET_OUTPUT_AS_3D // Initialize the accumulators // MMUL extension accumulate the result in F32 for both F32 and F16 @@ -525,4 +553,4 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture( #undef RHS_OFFSET_X #undef RHS_STEP_X } -#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL_TEXTURE)
\ No newline at end of file +#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL_TEXTURE)