aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl32
-rw-r--r--src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp9
2 files changed, 36 insertions, 5 deletions
diff --git a/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl
index 8919023d4c..09b8956b68 100644
--- a/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl
+++ b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022 Arm Limited.
+ * Copyright (c) 2022-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -117,9 +117,23 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul(
uint rhs_y = block_id;
// Compute LHS/RHS/DST matrix address
+#ifdef REINTERPRET_INPUT_AS_3D
+ lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + (lhs_y + z * M) * lhs_stride_y;
+#else // REINTERPRET_INPUT_AS_3D
lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+#endif // REINTERPRET_INPUT_AS_3D
+
+#ifdef BATCHED_RHS
rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
+#else // BATCHED_RHS
+ rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y;
+#endif // BATCHED_RHS
+
+#ifdef REINTERPRET_OUTPUT_AS_3D
+ dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + (dst_y + z * M) * dst_stride_y;
+#else // REINTERPRET_OUTPUT_AS_3D
dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+#endif // REINTERPRET_OUTPUT_AS_3D
// Note: If RHS derives from the weights of convolution 2d layer, RHS will always be 2D and rhs_stride_z will always be equal to 0 for
// not sliding the tensor
@@ -367,11 +381,25 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture(
// Starting RHS coordinates
uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0;
+
+#ifdef BATCHED_RHS
uint rhs_y = block_id + z * rhs_h;
+#else // BATCHED_RHS
+ uint rhs_y = block_id;
+#endif // BATCHED_RHS
// Compute LHS/RHS/DST matrix address
+#ifdef REINTERPRET_INPUT_AS_3D
+ lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + (lhs_y + z * M) * lhs_stride_y;
+#else // REINTERPRET_INPUT_AS_3D
lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+#endif // REINTERPRET_INPUT_AS_3D
+
+#ifdef REINTERPRET_OUTPUT_AS_3D
+ dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + (dst_y + z * M) * dst_stride_y;
+#else // REINTERPRET_OUTPUT_AS_3D
dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+#endif // REINTERPRET_OUTPUT_AS_3D
// Initialize the accumulators
// MMUL extension accumulate the result in F32 for both F32 and F16
@@ -525,4 +553,4 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture(
#undef RHS_OFFSET_X
#undef RHS_STEP_X
}
-#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL_TEXTURE) \ No newline at end of file
+#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL_TEXTURE)
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp
index 4ca5580443..4ca4b83f9c 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp
@@ -41,7 +41,6 @@
#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
#include "support/Cast.h"
#include "support/StringSupport.h"
-
namespace arm_compute
{
namespace opencl
@@ -101,7 +100,7 @@ Status validate_arguments(const ITensorInfo *src0,
ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(0) != k);
// Validate the reinterpreted-as-3D-case
- if (gemm_info.reinterpret_input_as_3d != 0)
+ if (gemm_info.reinterpret_input_as_3d)
{
ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) * src0->dimension(2) != m);
}
@@ -284,8 +283,12 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure(const CLCompileCon
build_opts.add_option("-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
build_opts.add_option("-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
+ build_opts.add_option_if(gemm_info.reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+ build_opts.add_option_if(gemm_info.depth_output_gemm3d != 0, "-DREINTERPRET_OUTPUT_AS_3D");
+ build_opts.add_option_if(src1->num_dimensions() > 2, "-DBATCHED_RHS");
+
std::string kernel_name("gemm_mm_reshaped_only_rhs_nt_mmul");
- kernel_name += rhs_info.export_to_cl_image ? "_texture" : "";
+ kernel_name += _export_to_cl_image ? "_texture" : "";
// A macro guard to compile ONLY the kernel of interest
build_opts.add_option("-D" + upper_string(kernel_name));