aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp')
-rw-r--r--src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp225
1 files changed, 162 insertions, 63 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp
index e70e5d5ea5..4bf0b76c3a 100644
--- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp
+++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp
@@ -28,6 +28,9 @@
#include "src/core/AccessWindowStatic.h"
#include "src/core/helpers/WindowHelpers.h"
+#include "src/core/utils/helpers/float_ops.h"
+#include "support/StringSupport.h"
+
namespace arm_compute
{
namespace experimental
@@ -214,6 +217,13 @@ std::string ClGemmNativeKernelComponent::get_additional_macros() const
std::string ClGemmNativeKernelComponent::get_component_code() const
{
+ auto t_lhs_info = _blueprint->impl().get_kernel_argument_info(_lhs.arg_id);
+ auto t_rhs_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id);
+
+ auto has_alpha = !(helpers::float_ops::is_one(_desc.alpha));
+ auto reinterpret_input_as_3d = _desc.reinterpret_input_as_3d && _desc.depth_output_gemm3d == 0;
+ auto dont_slide_b = t_rhs_info->num_dimensions() < t_lhs_info->num_dimensions();
+
std::string code = R"_(
//------------------ START KERNEL {{meta_kernel_id}} ---------------------
// IN_0(lhs) {{lhs}}
@@ -245,34 +255,49 @@ std::string ClGemmNativeKernelComponent::get_component_code() const
// Compute RHS matrix address
uint rhs_offset = {{rhs}}_offset_first_element_in_bytes + g_x * N0 * sizeof(DATA_TYPE);
+ )_";
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- rhs_offset += (g_z % MATRIX_B_DEPTH) * {{rhs}}_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- rhs_offset += g_z * {{rhs}}_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
+ if(dont_slide_b)
+ {
+ code += R"_(
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ rhs_offset += (g_z % {{MATRIX_B_DEPTH}}) * {{rhs}}_stride_z;
+ )_";
+ }
+ else
+ {
+ code += R"_(
+ rhs_offset += g_z * {{rhs}}_stride_z;
+ )_";
+ }
+ code += R"_(
REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
+ )_";
-#if defined(REINTERPRET_INPUT_AS_3D)
- // The plane (zlhs) is calculated dividing M (g_y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(g_y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, {{lhs}}_cross_plane_pad, {{lhs}}_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply lhs_stride_z by DEPTH_GEMM3D
- lhs_offset += g_z * {{lhs}}_stride_z * DEPTH_GEMM3D;
-
-#else // defined(REINTERPRET_INPUT_AS_3D)
-
- // Add offset for batched GEMM
- lhs_offset += g_z * {{lhs}}_stride_z;
+ if(reinterpret_input_as_3d)
+ {
+ code += R"_(
+ // The plane (zlhs) is calculated dividing M (g_y * M0) by HEIGHT_GEMM3D
+ CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(g_y, M0, PARTIAL_STORE_M0), {{HEIGHT_GEMM3D}}, {{DEPTH_GEMM3D}}, {{lhs}}_cross_plane_pad, {{lhs}}_stride_y);
-#endif // defined(REINTERPRET_INPUT_AS_3D)
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply lhs_stride_z by DEPTH_GEMM3D
+ lhs_offset += g_z * {{lhs}}_stride_z * {{DEPTH_GEMM3D}};
+ )_";
+ }
+ else
+ {
+ code += R"_(
+ // Add offset for batched GEMM
+ lhs_offset += g_z * {{lhs}}_stride_z;
+ )_";
+ }
+ code += R"_(
int i = 0;
-#if K0 > 1
- for(; i <= (K - K0); i += K0)
+#if {{K0}} > 1
+ for(; i <= (K - {{K0}}); i += {{K0}})
{
// Supported cases (M0, K0):
// 1,2 - 1,3 - 1,4 - 1,8 - 1,16
@@ -284,26 +309,26 @@ std::string ClGemmNativeKernelComponent::get_component_code() const
// 7,2 - 7,3 - 7,4 - 7,8 - 7,16
// 8,2 - 8,3 - 8,4 - 8,8 - 8,16
// Load values from LHS matrix
- LOAD_BLOCK(M0, K0, DATA_TYPE, a, {{lhs}}_ptr, lhs_offset, {{lhs}}_stride_y, zlhs);
+ LOAD_BLOCK(M0, {{K0}}, DATA_TYPE, a, {{lhs}}_ptr, lhs_offset, {{lhs}}_stride_y, zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(K0, N0, DATA_TYPE, b, {{rhs}}_ptr, rhs_offset, {{rhs}}_stride_y, g_zero);
+ LOAD_BLOCK({{K0}}, N0, DATA_TYPE, b, {{rhs}}_ptr, rhs_offset, {{rhs}}_stride_y, g_zero);
RHS_VFMA_M0xN0(0, a, b0, {{dst}});
RHS_VFMA_M0xN0(1, a, b1, {{dst}});
-#if K0 > 2
+#if {{K0}} > 2
RHS_VFMA_M0xN0(2, a, b2, {{dst}});
#endif // K0 > 2
-#if K0 > 3
+#if {{K0}} > 3
RHS_VFMA_M0xN0(3, a, b3, {{dst}});
#endif // K0 > 3
-#if K0 > 4
+#if {{K0}} > 4
RHS_VFMA_M0xN0(4, a, b4, {{dst}});
RHS_VFMA_M0xN0(5, a, b5, {{dst}});
RHS_VFMA_M0xN0(6, a, b6, {{dst}});
RHS_VFMA_M0xN0(7, a, b7, {{dst}});
#endif // K0 > 4
-#if K0 > 8
+#if {{K0}} > 8
RHS_VFMA_M0xN0(8, a, b8, {{dst}});
RHS_VFMA_M0xN0(9, a, b9, {{dst}});
RHS_VFMA_M0xN0(A, a, bA, {{dst}});
@@ -314,8 +339,8 @@ std::string ClGemmNativeKernelComponent::get_component_code() const
RHS_VFMA_M0xN0(F, a, bF, {{dst}});
#endif // K0 > 8
- lhs_offset += K0 * sizeof(DATA_TYPE);
- rhs_offset += K0 * {{rhs}}_stride_y;
+ lhs_offset += {{K0}} * sizeof(DATA_TYPE);
+ rhs_offset += {{K0}} * {{rhs}}_stride_y;
}
#endif // K0 > 1
// Left-over accumulations
@@ -362,44 +387,61 @@ std::string ClGemmNativeKernelComponent::get_component_code() const
}
// Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
- SCALE_BLOCK(M0, DATA_TYPE, {{dst}}, ALPHA);
-#endif // defined(ALPHA)
)_";
-
- if(!_bias.is_empty())
+ if(has_alpha)
{
code += R"_(
- // Add beta*bias
-#if defined(BROADCAST_BIAS)
- __global uchar *bias_addr = {{bias}}_ptr + {{bias}}_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
-
- LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, {{bias}}_stride_y, g_zero);
-
-#ifndef UNIT_BETA
- SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
-#endif // UNIT_BIAS
-
- // c = c + bias[broadcasted]
- ADD_BLOCK_BROADCAST(M0, {{dst}}, bias0);
-
-#else // defined(BROADCAST_BIAS)
- __global uchar *bias_addr = {{bias}}_ptr + {{bias}}_offset_first_element_in_bytes + (g_x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(g_y, M0,
- PARTIAL_STORE_M0)
- * {{bias}}_stride_y)
- + g_z * {{bias}}_stride_z;
-
- LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, {{bias}}_stride_y, g_zero);
-
-#ifndef UNIT_BETA
- SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
-#endif // UNIT_BIAS
-
- // c = c + bias
- ADD_BLOCK(M0, {{dst}}, bias);
+ SCALE_BLOCK(M0, DATA_TYPE, {{dst}}, {{ALPHA}});
+ )_";
+ }
-#endif // defined(BROADCAST_BIAS)
- )_";
+ // Optional bias stage: dst += beta * bias. Nothing is emitted when no bias tensor is attached.
+ if(!_bias.is_empty())
+ {
+ // Scaling the bias by beta is only needed for a non-unit beta; multiplying by 1 is redundant.
+ // Same convention as has_alpha at the top of this function.
+ const bool has_beta_scale = !(helpers::float_ops::is_one(_desc.beta));
+
+ if(_desc.broadcast_bias)
+ {
+ // Broadcast case: a single bias row (1 x N0) is loaded and added to every row of the block.
+ code += R"_(
+ // Add beta*bias
+ __global uchar *bias_addr = {{bias}}_ptr + {{bias}}_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+ LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, {{bias}}_stride_y, g_zero);
+ )_";
+
+ if(has_beta_scale)
+ {
+ code += R"_(
+ SCALE_BLOCK(1, DATA_TYPE, bias, {{BETA}});
+ )_";
+ }
+
+ code += R"_(
+ // c = c + bias[broadcasted]
+ ADD_BLOCK_BROADCAST(M0, {{dst}}, bias0);
+ )_";
+ }
+ else
+ {
+ // Non-broadcast case: a full M0 x N0 bias block is loaded, addressed like the output block.
+ code += R"_(
+ // Add beta*bias
+ __global uchar *bias_addr = {{bias}}_ptr + {{bias}}_offset_first_element_in_bytes + (g_x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(g_y, M0,
+ PARTIAL_STORE_M0)
+ * {{bias}}_stride_y)
+ + g_z * {{bias}}_stride_z;
+
+ LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, {{bias}}_stride_y, g_zero);
+ )_";
+
+ if(has_beta_scale)
+ {
+ code += R"_(
+ SCALE_BLOCK(M0, DATA_TYPE, bias, {{BETA}});
+ )_";
+ }
+
+ code += R"_(
+ // c = c + bias
+ ADD_BLOCK(M0, {{dst}}, bias);
+ )_";
+ }
code += R"_(
@@ -409,6 +451,25 @@ std::string ClGemmNativeKernelComponent::get_component_code() const
return code.c_str();
}
+/** Assemble the compile-time definitions (-D options) used by the GEMM native kernel.
+ *
+ * DATA_TYPE is taken from the blueprint's destination argument, M/N from the tile boundaries,
+ * M0/N0 from the tile dimensions and K from the GEMM descriptor. PARTIAL_STORE_M0/N0 are the
+ * remainders of the boundaries divided by the tile dimensions.
+ *
+ * @return The populated CLBuildOptions set.
+ */
+CLBuildOptions ClGemmNativeKernelComponent::generate_build_options() const
+{
+ const auto dst_arg_id = _blueprint->impl().get_dst_id();
+ auto dst_info = _blueprint->impl().get_kernel_argument_info(dst_arg_id);
+ auto tiling = _blueprint->impl().get_tile_info();
+
+ // Overall extents (m, n) and per-tile block sizes (m0, n0)
+ const auto m = tiling.boundaries.y();
+ const auto n = tiling.boundaries.x();
+ const auto m0 = tiling.tile_dims.y();
+ const auto n0 = tiling.tile_dims.x();
+
+ CLBuildOptions options{};
+
+ options.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst_info->data_type()));
+ options.add_option("-DM=" + support::cpp11::to_string(m));
+ options.add_option("-DN=" + support::cpp11::to_string(n));
+ options.add_option("-DK=" + support::cpp11::to_string(_desc.k));
+ options.add_option("-DM0=" + support::cpp11::to_string(m0));
+ options.add_option("-DN0=" + support::cpp11::to_string(n0));
+ // Leftover rows/columns when the extents are not a multiple of the tile size
+ options.add_option("-DPARTIAL_STORE_M0=" + support::cpp11::to_string(m % m0));
+ options.add_option("-DPARTIAL_STORE_N0=" + support::cpp11::to_string(n % n0));
+
+ return options;
+}
+
ClGemmNativeKernelComponent::TagLUT ClGemmNativeKernelComponent::allocate_vars(SharedVarTable &vtable) const
{
TagLUT lut{};
@@ -421,6 +482,44 @@ ClGemmNativeKernelComponent::TagLUT ClGemmNativeKernelComponent::allocate_vars(S
lut["bias"] = vtable.add(_bias, ClKernelArgRuntimeDescriptor(_bias.arg_id, TensorArgType::Image_3D), "bias");
}
lut["dst"] = vtable.add(_dst, ClKernelArgRuntimeDescriptor(_dst.arg_id, TensorArgType::Image_3D), "dst");
+
+ // Local build options
+ // These are registered as string tags in the LUT; get_component_code() refers to them as
+ // {{K0}}, {{ALPHA}}, {{BETA}}, {{MATRIX_B_DEPTH}}, {{HEIGHT_GEMM3D}} and {{DEPTH_GEMM3D}}.
+ // NOTE(review): the tensor infos below are dereferenced without a null check - confirm the
+ // blueprint always holds argument info for lhs, rhs and dst at this point.
+ auto t_lhs_info = _blueprint->impl().get_kernel_argument_info(_lhs.arg_id);
+ auto t_rhs_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id);
+ auto t_dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id());
+
+ // alpha is only needed when it differs from 1 (same test as in get_component_code())
+ auto has_alpha = !(helpers::float_ops::is_one(_desc.alpha));
+ // has_beta reflects whether argument info exists for the bias tensor, not the value of beta
+ auto has_beta = _blueprint->impl().get_kernel_argument_info(_bias.arg_id) != nullptr;
+ // The two reinterpret flags are mutually exclusive by construction
+ auto reinterpret_input_as_3d = _desc.reinterpret_input_as_3d && _desc.depth_output_gemm3d == 0;
+ auto reinterpret_output_as_3d = !_desc.reinterpret_input_as_3d && _desc.depth_output_gemm3d != 0;
+ // rhs has fewer dimensions than lhs: the kernel wraps g_z with MATRIX_B_DEPTH instead of sliding rhs
+ auto dont_slide_b = t_rhs_info->num_dimensions() < t_lhs_info->num_dimensions();
+
+ // K0 is always registered
+ lut["K0"] = support::cpp11::to_string(_desc.rhs_info.k0);
+
+ if(has_alpha)
+ {
+ lut["ALPHA"] = float_to_string_with_full_precision(_desc.alpha);
+ }
+ if(has_beta)
+ {
+ lut["BETA"] = float_to_string_with_full_precision(_desc.beta);
+ }
+ if(dont_slide_b)
+ {
+ lut["MATRIX_B_DEPTH"] = support::cpp11::to_string(t_rhs_info->dimension(2));
+ }
+
+ // HEIGHT_GEMM3D / DEPTH_GEMM3D come from dst when the output is reinterpreted, otherwise from lhs
+ // NOTE(review): the visible part of get_component_code() only uses these tags on the
+ // reinterpret_input_as_3d path - confirm the output-as-3D values are consumed elsewhere.
+ if(reinterpret_output_as_3d)
+ {
+ lut["HEIGHT_GEMM3D"] = support::cpp11::to_string(t_dst_info->dimension(1));
+ lut["DEPTH_GEMM3D"] = support::cpp11::to_string(t_dst_info->dimension(2));
+ }
+ else if(reinterpret_input_as_3d)
+ {
+ lut["HEIGHT_GEMM3D"] = support::cpp11::to_string(t_lhs_info->dimension(1));
+ lut["DEPTH_GEMM3D"] = support::cpp11::to_string(t_lhs_info->dimension(2));
+ }
+
return lut;
}
} // namespace dynamic_fusion