aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp9
1 files changed, 5 insertions, 4 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 2761247684..674937eff0 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -265,6 +265,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Do not slide matrix B if _slide_matrix_b = false
build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
+ const bool is_bifrost = get_arch_from_target(gpu_target) == GPUTarget::BIFROST;
+
std::string kernel_name;
if(is_interleaved_transposed)
{
@@ -275,10 +277,9 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
build_opts.add_option("-DMULT_TRANSPOSE1XW_WIDTH=" + support::cpp11::to_string(mult_transpose1xW_width));
build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
- if(data_type == DataType::F32)
+ if(is_data_type_float(data_type) && is_bifrost)
{
- GPUTarget arch_target = get_arch_from_target(gpu_target);
- kernel_name = "gemm_mm_interleaved_transposed_f32_" + string_from_target(arch_target);
+ kernel_name = "gemm_mm_interleaved_transposed_" + lower_string(string_from_data_type(data_type)) + "_bifrost";
}
else
{
@@ -291,7 +292,7 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
// Create kernels according to the architecture, data type and input size.
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && is_data_type_float(data_type))
+ if(is_data_type_float(data_type) && is_bifrost)
{
kernel_name = "gemm_mm_floating_point";