diff options
-rw-r--r-- | src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl | 32 | ||||
-rw-r--r-- | src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp | 9 |
2 files changed, 36 insertions, 5 deletions
diff --git a/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl index 8919023d4c..09b8956b68 100644 --- a/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl +++ b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Arm Limited. + * Copyright (c) 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -117,9 +117,23 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul( uint rhs_y = block_id; // Compute LHS/RHS/DST matrix address +#ifdef REINTERPRET_INPUT_AS_3D + lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + (lhs_y + z * M) * lhs_stride_y; +#else // REINTERPRET_INPUT_AS_3D lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z; +#endif // REINTERPRET_INPUT_AS_3D + +#ifdef BATCHED_RHS rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z; +#else // BATCHED_RHS + rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y; +#endif // BATCHED_RHS + +#ifdef REINTERPRET_OUTPUT_AS_3D + dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + (dst_y + z * M) * dst_stride_y; +#else // REINTERPRET_OUTPUT_AS_3D dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z; +#endif // REINTERPRET_OUTPUT_AS_3D // Note: If RHS derives from the weights of convolution 2d layer, RHS will always be 2D and rhs_stride_z will always be equal to 0 for // not sliding the tensor @@ -367,11 +381,25 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture( // Starting RHS coordinates uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0; + +#ifdef BATCHED_RHS uint rhs_y = block_id + z * rhs_h; +#else // BATCHED_RHS + uint rhs_y = block_id; +#endif // BATCHED_RHS // Compute LHS/RHS/DST matrix address +#ifdef REINTERPRET_INPUT_AS_3D + lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + (lhs_y + z * M) * lhs_stride_y; +#else // REINTERPRET_INPUT_AS_3D lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z; +#endif // REINTERPRET_INPUT_AS_3D + +#ifdef REINTERPRET_OUTPUT_AS_3D + dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + (dst_y + z * M) * dst_stride_y; +#else // REINTERPRET_OUTPUT_AS_3D dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z; +#endif // REINTERPRET_OUTPUT_AS_3D // Initialize the accumulators // MMUL extension accumulate the result in F32 for both F32 and F16 @@ -525,4 +553,4 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture( #undef RHS_OFFSET_X #undef RHS_STEP_X } -#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL_TEXTURE)
\ No newline at end of file +#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL_TEXTURE) diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp index 4ca5580443..4ca4b83f9c 100644 --- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp @@ -41,7 +41,6 @@ #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" #include "support/Cast.h" #include "support/StringSupport.h" - namespace arm_compute { namespace opencl @@ -101,7 +100,7 @@ Status validate_arguments(const ITensorInfo *src0, ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(0) != k); // Validate the reinterpreted-as-3D-case - if (gemm_info.reinterpret_input_as_3d != 0) + if (gemm_info.reinterpret_input_as_3d) { ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) * src0->dimension(2) != m); } @@ -284,8 +283,12 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure(const CLCompileCon build_opts.add_option("-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a())); build_opts.add_option("-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b())); + build_opts.add_option_if(gemm_info.reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D"); + build_opts.add_option_if(gemm_info.depth_output_gemm3d != 0, "-DREINTERPRET_OUTPUT_AS_3D"); + build_opts.add_option_if(src1->num_dimensions() > 2, "-DBATCHED_RHS"); + std::string kernel_name("gemm_mm_reshaped_only_rhs_nt_mmul"); - kernel_name += rhs_info.export_to_cl_image ? "_texture" : ""; + kernel_name += _export_to_cl_image ? "_texture" : ""; // A macro guard to compile ONLY the kernel of interest build_opts.add_option("-D" + upper_string(kernel_name)); |