diff options
Diffstat (limited to 'src/gpu/cl/kernels/gemm/reshaped_only_rhs')
5 files changed, 1556 insertions, 0 deletions
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp new file mode 100644 index 0000000000..c4825bfbeb --- /dev/null +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp @@ -0,0 +1,577 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" + +#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include <utility> + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +using namespace arm_compute::misc::shape_calculator; + +ClGemmDefaultConfigReshapedRhsOnlyBifrost::ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ( + ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G31( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + switch (_target) + { + case GPUTarget::G76: + func = configs_G76.get_function(data_type); + break; + case GPUTarget::G51: + func = configs_G51.get_function(data_type); + break; + case GPUTarget::G52: + func = configs_G52.get_function(data_type); + break; + case GPUTarget::G31: + func = configs_G31.get_function(data_type); + break; + default: + func = configs_G7x.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (m == 1) + { + if (n <= 2548) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true); + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1); + } + else + { + const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1)); + if (m >= 28) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1); + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + const bool is_workload_big = ((m * n * b) / 16) >= 2048; + + if (m == 1) + { + if (n >= 8192) + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + if (n <= 204) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false); + } + } + } + else + { + const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1)); + if (is_workload_big) + { + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true); + } + else + { + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true); + } + } + + // Get lhs_info/rhs_info in case of OpenCL image + const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1)); + if (is_workload_big) + { + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true); + } + + const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32); + const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img); + const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32); + + // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d + const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true; + + if (bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) + { + return std::make_pair(lhs_info_img, rhs_info_img); + } + else + { + return std::make_pair(lhs_info_buf, rhs_info_buf); + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; + const float r_nk = static_cast<float>(n) / static_cast<float>(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if (m == 1) + { + if (r_nk <= 0.4664f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); + } + } + else + { + if (workload <= 274.4000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (m == 1) + { + const unsigned int n0 = n < 1280 ? 2 : 4; + const unsigned int h0 = std::max(n / n0, 1U); + return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (m == 1) + { + if (n > 2048) + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true); + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast<float>(m) / static_cast<float>(n); + const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; + const float r_mk = static_cast<float>(m) / static_cast<float>(k); + const float r_nk = static_cast<float>(n) / static_cast<float>(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if (m == 1) + { + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false); + + if (r_mk <= 0.0026f) + { + if (r_nk <= 0.4664f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); + } + } + else + { + if (r_mk <= 0.0148f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); + } + } + } + else + { + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false); + + if (workload <= 362.6000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); + } + else + { + if (r_mn <= 22.6067f) + { + if (workload <= 708.8000f) + { + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false); + } + } + else + { + if (r_nk <= 0.0917f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); + } + } + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + + if (m == 1) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + const float r_mn = static_cast<float>(m) / static_cast<float>(n); + const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; + + if (workload <= 7449.60f) + { + if (workload <= 691.60f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false); + } + else + { + if (workload <= 4155.20f) + { + return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false); + } + } + } + else + { + if (workload <= 16300.80f) + { + if (r_mn <= 44.56f) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); + } + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (m == 1) + { + const unsigned int n0 = n < 1280 ? 2 : 4; + const unsigned int h0 = std::max(n / n0, 1U); + return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (dot8_supported(CLKernelLibrary::get().get_device())) + { + if (m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true); + } + } + else + { + const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1)); + if (m == 1) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true); + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true); + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true); + } +} + +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h new file mode 100644 index 0000000000..77c0c8d500 --- /dev/null +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_BIFROST_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_BIFROST_H + +#include "src/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Bifrost based OpenCL GEMMReshapedOnlyRHS configuration */ +class ClGemmDefaultConfigReshapedRhsOnlyBifrost final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu); + + // Inherited overridden method + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_BIFROST_H */ diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp new file mode 100644 index 0000000000..da3e2ec912 --- /dev/null +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp @@ -0,0 +1,755 @@ +/* + * Copyright (c) 2020-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" + +#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" +#include "src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h" + +#include <utility> + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +using namespace arm_compute::misc::shape_calculator; + +ClGemmDefaultConfigReshapedRhsOnlyValhall::ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ( + ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77( + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78( + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G710( + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715( + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + + switch (_target) + { + case GPUTarget::G78: + func = configs_G78.get_function(data_type); + break; + case GPUTarget::G710: + case GPUTarget::G610: + func = configs_G710.get_function(data_type); + break; + case GPUTarget::G715: + case GPUTarget::G615: + func = configs_G715.get_function(data_type); + break; + case GPUTarget::G77: + default: + func = configs_G77.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + if (m == 1) + { + const float r_mn = static_cast<float>(m) / static_cast<float>(n); + const float r_mk = static_cast<float>(m) / static_cast<float>(k); + + if (r_mk <= 0.0064484127797186375) + { + if (r_mn <= 0.0028273810748942196) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + const unsigned int h0 = std::max(n / 4, 1U); + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 1, 0, 0, 0); + } + } + else + { + if (r_mk <= 0.020312500186264515) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, 0, 1, 0, 1, 0); + } + } + } + else + { + const float r_mn = static_cast<float>(m) / static_cast<float>(n); + const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; + const float r_mk = static_cast<float>(m) / static_cast<float>(k); + + if (workload <= 1999.2000122070312) + { + if (workload <= 747.1999816894531) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); + } + } + else + { + if (r_mn <= 0.03348214365541935) + { + if (r_mk <= 0.028125000186264515) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, 0, 1, 0, 0, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); + } + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const GeMMConfigsMatrix configs_1nkb_best = { + {1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0}, {1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, + {1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0}, + {1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0}, {1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0}, + {1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, {1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_n_small_best = {{102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0}, + {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1}, + {16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1}, + {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1}}; + + const GeMMConfigsMatrix configs_mnkb_n_small_fallback = {{102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0}, + {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0}, + {16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0}, + {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0}}; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = { + {25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1}, + {369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1}, {65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, + {23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, + {8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, + {16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, + {29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0}, + {2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = { + {25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0}, + {369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, + {23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, + {8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, + {16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, {12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0}, + {29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0}, + {2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = { + {24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0}, + {49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1}, + {49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, + }; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = { + {24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0}, + {49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0}, + }; + + const GeMMConfigsMatrix configs_mnkb_squared_best = { + {72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0}, + {180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1}, + {24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}}; + + const GeMMConfigsMatrix configs_mnkb_squared_fallback = { + {72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0}, + {180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, {1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0}, + {24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}}; + + const GeMMConfigsMatrix configs_mnkb_best_batched = { + {3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, + {688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, + {1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_fallback_batched = { + {3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, + {1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}}; + + const GeMMConfigsMatrix *configs_best_to_use = nullptr; + const GeMMConfigsMatrix *configs_fallback_to_use = nullptr; + + if (b == 1) + { + constexpr float ratio_m_gt_n = 10.f; + constexpr float ratio_n_gt_m = 0.1f; + constexpr unsigned int n_small_thr = 4; + const float ratio = static_cast<float>(m) / static_cast<float>(n); + + if (m == 1) + { + // We do not need fallback in this case, as we never use cl_image for the rhs tensor + configs_best_to_use = &configs_1nkb_best; + configs_fallback_to_use = &configs_1nkb_best; + } + else if (n <= n_small_thr && ratio > ratio_m_gt_n) + { + configs_best_to_use = &configs_mnkb_n_small_best; + configs_fallback_to_use = &configs_mnkb_n_small_fallback; + } + else if (ratio > ratio_m_gt_n) + { + configs_best_to_use = &configs_mnkb_m_gt_n_best; + configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback; + } + else if (ratio < ratio_n_gt_m) + { + configs_best_to_use = &configs_mnkb_n_gt_m_best; + configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback; + } + else + { + configs_best_to_use = &configs_mnkb_squared_best; + configs_fallback_to_use = &configs_mnkb_squared_fallback; + } + } + else + { + configs_best_to_use = &configs_mnkb_best_batched; + configs_fallback_to_use = &configs_mnkb_fallback_batched; + } + + GEMMLHSMatrixInfo lhs_info0; + GEMMRHSMatrixInfo rhs_info0; + GEMMLHSMatrixInfo lhs_info1; + GEMMRHSMatrixInfo rhs_info1; + + std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b); + std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b); + + return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), std::make_pair(lhs_info1, rhs_info1), n, k, b, + DataType::F16); +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if (m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1); + } + else + { + const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1)); + if (m >= 28) + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 1); + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast<float>(m) / static_cast<float>(n); + const float r_mk = static_cast<float>(m) / static_cast<float>(k); + const float r_nk = static_cast<float>(n) / static_cast<float>(k); + const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; + + if (m == 1) + { + if (workload <= 278.7000f) + { + if (workload <= 7.5000f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + if (r_mn <= 0.0031f) + { + if (workload <= 256.6000f) + { + if (workload <= 16.7500f) + { + if (r_nk <= 1.6671f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + else + { + if (r_mk <= 0.0027f) + { + if (r_mk <= 0.0014f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + else + { + if (workload <= 8.9500f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + } + else + { + if (workload <= 14.1500f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + if (r_mk <= 0.0041f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + } + } + } + } + } + else + { + if (workload <= 363.7000f) + { + if (r_mk <= 0.0031f) + { + return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 32, 0, 1, 0, 1, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0); + } + } + } + else + { + if (workload <= 1384.8000f) + { + if (workload <= 704.0000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1); + } + } + else + { + if (workload <= 16761.6006f) + { + if (r_mn <= 187.1250f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1); + } + } + else + { + if (r_mk <= 432.4630f) + { + return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 16, 0, 1, 0, 1, 1); + } + } + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; + const float r_mn = static_cast<float>(m) / static_cast<float>(n); + const float r_mk = static_cast<float>(m) / static_cast<float>(k); + const float r_nk = static_cast<float>(n) / static_cast<float>(k); + + if (m == 1) + { + const GeMMConfigsMatrix configs_mnkb_best = { + {1, 8984, 640, 1, 1, 4, 2, 1, 0, 1, 0, 1, 1, 0}, {1, 420, 392, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, + {1, 644, 5288, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 2, 2, 1, 0, 1, 0, 1, 1, 0}, + {1, 5304, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0}, {1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, + {1, 4096, 25088, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 732, 8988, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}}; + + return find_lhs_rhs_info(configs_mnkb_best, m, n, k, b); + } + else + { + if (workload <= 1384.8000f) + { + if (r_nk <= 0.8333f) + { + if (r_mk <= 0.9119f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1); + } + else + { + if (r_nk <= 0.1181f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + } + } + else + { + if (r_mk <= 1.0013f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1); + } + } + } + else + { + if (workload <= 11404.7998f) + { + if (r_mk <= 2.2884f) + { + if (r_nk <= 0.9286f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1); + } + } + else + { + return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1); + } + } + else + { + if (r_nk <= 1.1926f) + { + if (r_mn <= 1385.7917f) + { + return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 32, 0, 1, 1, 0, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 32, 0, 1, 1, 0, 1); + } + } + } + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + unsigned int best_m0; + unsigned int best_n0; + + if (is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0)) + { + return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true); + } + else + { + return configure_G77_f32(m, n, k, b); + } +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const GeMMConfigsMatrix configs_1nkb_best = { + {1, 8984, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0}, {1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, + {1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, + {1, 5304, 640, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, + {1, 4096, 25088, 1, 1, 2, 8, 1, 0, 1, 0, 1, 1, 0}, {1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_n_small_best = {{102400, 4, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, + {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, + {16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, + {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = { + {25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 16, 1, 8, 1, 1, 1, 0, 1}, + {369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, + {23036, 56, 736, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {8944, 32, 776, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {2688, 136, 1492, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {50176, 64, 300, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}, {16544, 104, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {12604, 60, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {3728, 96, 196, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0}, + }; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = { + {25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}, + {369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, + {23036, 56, 736, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 0}, + {8944, 32, 776, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0}, {2688, 136, 1492, 1, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0}, + {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {16544, 104, 160, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, + {12604, 60, 160, 1, 2, 8, 8, 1, 8, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 2, 8, 8, 1, 64, 1, 1, 1, 0, 0}, + {29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0}, + }; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = {{24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0}, + {49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0}, + {49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}}; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = {{24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0}, + {49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0}, + {49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}}; + + const GeMMConfigsMatrix configs_mnkb_squared_best = { + {24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, + {72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {180, 420, 952, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1}, {1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0}, + {272, 400, 2116, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {196, 512, 512, 1, 5, 2, 8, 1, 4, 1, 1, 1, 1, 1}, + }; + + const GeMMConfigsMatrix configs_mnkb_squared_fallback = { + {24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, + {72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, + {180, 420, 952, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0}, {1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0}, + {272, 400, 2116, 1, 2, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0}, + }; + + const GeMMConfigsMatrix configs_mnkb_best_batched = { + {3136, 64, 64, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 1}, {4096, 48, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 1}, {24, 464, 412, 24, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {112, 184, 144, 28, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {5776, 64, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {1568, 64, 40, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}, {2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}}; + + const GeMMConfigsMatrix configs_mnkb_fallback_batched = { + {3136, 64, 64, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, + {688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 2, 8, 4, 1, 32, 1, 1, 1, 0, 0}, + {112, 184, 144, 28, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 2, 8, 8, 1, 32, 1, 1, 1, 0, 0}, + {1568, 64, 40, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}}; + + const GeMMConfigsMatrix *configs_best_to_use = nullptr; + const GeMMConfigsMatrix *configs_fallback_to_use = nullptr; + + if (b == 1) + { + constexpr float ratio_m_gt_n = 10.f; + constexpr float ratio_n_gt_m = 0.1f; + constexpr unsigned int n_small_thr = 4; + const float ratio = static_cast<float>(m) / static_cast<float>(n); + + if (m == 1) + { + // We do not need fallback in this case, as we never use cl_image for the rhs tensor + configs_best_to_use = &configs_1nkb_best; + configs_fallback_to_use = &configs_1nkb_best; + } + else if (n <= n_small_thr && ratio > ratio_m_gt_n) + { + configs_best_to_use = &configs_mnkb_n_small_best; + configs_fallback_to_use = &configs_mnkb_n_small_best; + } + else if (ratio > ratio_m_gt_n) + { + configs_best_to_use = &configs_mnkb_m_gt_n_best; + configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback; + } + else if (ratio < ratio_n_gt_m) + { + configs_best_to_use = &configs_mnkb_n_gt_m_best; + configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback; + } + else + { + configs_best_to_use = &configs_mnkb_squared_best; + configs_fallback_to_use = &configs_mnkb_squared_fallback; + } + } + else + { + configs_best_to_use = &configs_mnkb_best_batched; + configs_fallback_to_use = &configs_mnkb_fallback_batched; + } + + GEMMLHSMatrixInfo lhs_info0; + GEMMRHSMatrixInfo rhs_info0; + GEMMLHSMatrixInfo lhs_info1; + GEMMRHSMatrixInfo rhs_info1; + + std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b); + std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b); + + return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), std::make_pair(lhs_info1, rhs_info1), n, k, b, + DataType::F16); +} + +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + unsigned int best_m0; + unsigned int best_n0; + + if (is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0)) + { + return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true); + } + else + { + return configure_G78_f16(m, n, k, b); + } +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h new file mode 100644 index 0000000000..a0ea337eb1 --- /dev/null +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2020-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_VALHALL_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_VALHALL_H + +#include "src/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Valhall based OpenCL GEMMReshapedOnlyRHS configuration */ +class ClGemmDefaultConfigReshapedRhsOnlyValhall final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu); + + // Inherited overridden method + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_VALHALL_H */ diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h new file mode 100644 index 0000000000..1f0c5c2d87 --- /dev/null +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2019-2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_GPU_CL_KERNELS_GEMM_RESHAPED_ONLY_RHS_CLGEMMRESHAPEDONLYRHSKERNELCONFIG_H +#define ACL_SRC_GPU_CL_KERNELS_GEMM_RESHAPED_ONLY_RHS_CLGEMMRESHAPEDONLYRHSKERNELCONFIG_H + +#include "src/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" +#include "src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h" +#include "src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h" + +#include <memory> + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** CLGEMMReshapedOnlyRHS factory class */ +class ClGemmReshapedOnlyRhsKernelConfigurationFactory final +{ +public: + /** Static method to call the CLGEMMReshapedOnlyRHS kernel configuration class accordingly with the GPU target + * + * @param[in] gpu GPU target + * + * @return CLGEMMReshapedOnlyRHS kernel configuration class + */ + static std::unique_ptr<IClGemmKernelConfig> create(GPUTarget gpu) + { + switch (get_arch_from_target(gpu)) + { + case GPUTarget::MIDGARD: + case GPUTarget::BIFROST: + return std::make_unique<ClGemmDefaultConfigReshapedRhsOnlyBifrost>(gpu); + case GPUTarget::VALHALL: + case GPUTarget::FIFTHGEN: + return std::make_unique<ClGemmDefaultConfigReshapedRhsOnlyValhall>(gpu); + default: + ARM_COMPUTE_ERROR("Not supported GPU target"); + } + } +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif // ACL_SRC_GPU_CL_KERNELS_GEMM_RESHAPED_ONLY_RHS_CLGEMMRESHAPEDONLYRHSKERNELCONFIG_H |