From 856f66e6c61b77d03f754cd0fa8439891f0e4aca Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 22 Apr 2021 21:13:21 +0100 Subject: Port CLGEMM to memory injecting interface Moves the following kernels: - CLGEMMMatrixMultiplyKernel - CLGEMMMatrixMultiplyNativeKernel - CLGEMMMatrixMultiplyReshapedKernel - CLGEMMMatrixMultiplyReshapedOnlyRHSKernel Moves the following functions: - CLGEMM Introduces facilities to ease handling of auxiliary temporary buffers under the new run interface. Such are: - CLAuxTensorHandler: That allows wrapping of workspace buffers memory to CLBuffer objects - Ability to inject TensorInfo to allocator without transferring ownership. This reduces the copy overhead if needed. Resolves: COMPMID-4188 Signed-off-by: Georgios Pinitas Change-Id: I7055435d831b05b749b26302082e4ac45f26dfb0 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5498 Tested-by: Arm Jenkins Reviewed-by: Michalis Spyrou Comments-Addressed: Arm Jenkins --- src/core/gpu/cl/kernels/gemm/ClGemmHelpers.cpp | 116 +++++ src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h | 95 ++++ src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h | 123 +++++ .../native/ClGemmDefaultConfigNativeBifrost.cpp | 246 +++++++++ .../gemm/native/ClGemmDefaultConfigNativeBifrost.h | 62 +++ .../native/ClGemmDefaultConfigNativeMidgard.cpp | 73 +++ .../gemm/native/ClGemmDefaultConfigNativeMidgard.h | 57 +++ .../native/ClGemmDefaultConfigNativeValhall.cpp | 168 ++++++ .../gemm/native/ClGemmDefaultConfigNativeValhall.h | 59 +++ .../kernels/gemm/native/ClGemmNativeKernelConfig.h | 71 +++ .../ClGemmDefaultConfigReshapedBifrost.cpp | 356 +++++++++++++ .../reshaped/ClGemmDefaultConfigReshapedBifrost.h | 64 +++ .../ClGemmDefaultConfigReshapedValhall.cpp | 538 +++++++++++++++++++ .../reshaped/ClGemmDefaultConfigReshapedValhall.h | 61 +++ .../gemm/reshaped/ClGemmReshapedKernelConfig.h | 69 +++ .../ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp | 518 +++++++++++++++++++ 
.../ClGemmDefaultConfigReshapedRhsOnlyBifrost.h | 67 +++ .../ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp | 570 +++++++++++++++++++++ .../ClGemmDefaultConfigReshapedRhsOnlyValhall.h | 61 +++ .../ClGemmDefaultReshapedRhsOnlyBifrost.cpp | 518 +++++++++++++++++++ .../ClGemmDefaultReshapedRhsOnlyValhall.cpp | 570 +++++++++++++++++++++ .../ClGemmReshapedOnlyRhsKernelConfig.h | 69 +++ 22 files changed, 4531 insertions(+) create mode 100644 src/core/gpu/cl/kernels/gemm/ClGemmHelpers.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h create mode 100644 src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h create mode 100644 src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.h create mode 100644 src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.h create mode 100644 src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.h create mode 100644 src/core/gpu/cl/kernels/gemm/native/ClGemmNativeKernelConfig.h create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.h create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped/ClGemmReshapedKernelConfig.h create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h create mode 100644 
src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultReshapedRhsOnlyBifrost.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultReshapedRhsOnlyValhall.cpp create mode 100644 src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h (limited to 'src/core/gpu/cl/kernels/gemm') diff --git a/src/core/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/core/gpu/cl/kernels/gemm/ClGemmHelpers.cpp new file mode 100644 index 0000000000..0a8ba971ed --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/ClGemmHelpers.cpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/CL/OpenCL.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +std::pair configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, + bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image) +{ + ARM_COMPUTE_ERROR_ON(m0 == 0 || n0 == 0); + v0 = std::max(std::min(static_cast(m / m0), static_cast(v0)), static_cast(1)); + h0 = std::max(std::min(static_cast(n / n0), static_cast(h0)), static_cast(1)); + + const GEMMLHSMatrixInfo lhs_info(m0, k0, v0, lhs_transpose, lhs_interleave); + const GEMMRHSMatrixInfo rhs_info(n0, k0, h0, rhs_transpose, rhs_interleave, export_to_cl_image); + + return std::make_pair(lhs_info, rhs_info); +} + +std::pair select_lhs_rhs_info(std::pair info_img, + std::pair info_buf, + unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, data_type); + const TensorShape shape = misc::shape_calculator::compute_rhs_reshaped_shape(tensor_rhs_info, info_img.second); + const TensorInfo tensor_reshaped_info(shape, 1, data_type); + + if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, info_img.second))) + { + return info_img; + } + else + { + return info_buf; + } +} + +void update_padding_for_cl_image(ITensorInfo *tensor) +{ + constexpr unsigned int num_floats_per_pixel = 4; + + const unsigned int stride_y_in_elements = tensor->strides_in_bytes()[1] / tensor->element_size(); + const unsigned int pixel_alignment = get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()); + + ARM_COMPUTE_ERROR_ON_MSG(pixel_alignment == 0, "Cannot retrieve cl_image 
pitch alignment"); + if(pixel_alignment == 0) + { + return; + } + + const unsigned int row_pitch_alignment = pixel_alignment * num_floats_per_pixel; + const unsigned int round_up_width = ((stride_y_in_elements + row_pitch_alignment - 1) / row_pitch_alignment) * row_pitch_alignment; + const unsigned int padding = round_up_width - stride_y_in_elements; + + tensor->extend_padding(PaddingSize(0, padding, 0, 0)); +} + +Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info) +{ + if(rhs_info.export_to_cl_image) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.n0 == 2) || (rhs_info.n0 == 3), "Export to cl_image only supported with n0 = 4, 8 or 16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 == 2) || (rhs_info.k0 == 3), "Export to cl_image only supported with k0 = 4, 8 or 16"); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(&tensor_reshaped_info, DataType::F32, DataType::F16); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!image2d_from_buffer_supported(CLKernelLibrary::get().get_device()), "The extension cl_khr_image2d_from_buffer is not supported on the target platform"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) == 0, "Impossible to retrieve the cl_image pitch alignment"); + + // Check the width and height of the output tensor. 
+ // Since we cannot create a 3d image from a buffer, the third dimension is collapsed on the second dimension + const size_t max_image_w = CLKernelLibrary::get().get_device().getInfo(); + const size_t max_image_h = CLKernelLibrary::get().get_device().getInfo(); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(tensor_reshaped_info.tensor_shape()[0] > max_image_w * 4, "Not supported width for cl_image"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(tensor_reshaped_info.tensor_shape()[1] * tensor_reshaped_info.tensor_shape()[2] > max_image_h, "Not supported height for cl_image"); + } + + return Status{}; +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h new file mode 100644 index 0000000000..3fce8c9173 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_GEMM_HELPERS_H +#define ARM_COMPUTE_CL_GEMM_HELPERS_H + +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Types.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Configure @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo + * + * @param[in] m Number of rows (M) in the LHS matrix not reshaped + * @param[in] n Number of columns (N) in the RHS matrix not reshaped + * @param[in] m0 Number of rows processed by each thread/work-item + * @param[in] n0 Number of columns processed by each thread/work-item + * @param[in] k0 Number of inner accumulation performed by each thread/work-item + * @param[in] v0 Number of vertical blocks of size (m0xk0) stored on the same output row + * @param[in] h0 Number of horizontal blocks of size (k0xn0) stored on the same output row + * @param[in] lhs_interleave True if the v0 (m0xk0) blocks have to be interleaved in the output row + * @param[in] rhs_interleave True if the h0 (k0xn0) blocks have to be interleaved in the output row + * @param[in] lhs_transpose True if the (m0xk0) block has to be transposed before been stored + * @param[in] rhs_transpose True if the (k0xn0) block has to be transposed before been stored + * @param[in] export_to_cl_image (Optional) True if the RHS reshaped matrix has to be exported to cl_image + * + * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo + */ +std::pair configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, + bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image = false); + +/** Select @ref 
GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo + * + * This function accepts two pairs of GEMMLHSMatrixInfo/GEMMRHSMatrixInfo where only the first is with cl_image2d support, + * and selects the valid one validating the GEMMRHSMatrixInfo. If the validation passes, the functions will return + * the first GEMMLHSMatrixInfo/GEMMRHSMatrixInfo pair with cl_image2d support. + * + * @param[in] info_img GEMMLHSMatrixInfo/GEMMRHSMatrixInfo with cl_image2d support + * @param[in] info_buf GEMMLHSMatrixInfo/GEMMRHSMatrixInfo to fall-back if cl_image2d cannot be used + * @param[in] n Number of columns (N) in the RHS matrix not reshaped + * @param[in] k Number of rows (K) in the RHS matrix not reshaped + * @param[in] b Batch size + * @param[in] data_type Data type + * + * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo + */ +std::pair select_lhs_rhs_info(std::pair info_img, + std::pair info_buf, + unsigned int n, unsigned int k, unsigned int b, DataType data_type); + +/** Update padding required to export the OpenCL buffer to OpenCL image2d + * + * @param[in,out] tensor ITensorInfo of the tensor required to be exported to OpenCL image2d + */ +void update_padding_for_cl_image(ITensorInfo *tensor); + +/** Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix + * + * @param[in] tensor_reshaped_info TensorInfo for the RHS reshaped matrix + * @param[in] rhs_info @ref GEMMRHSMatrixInfo + * + * @return Status reporting if we can use the image2d OpenCL object on the RHS reshaped matrix + */ +Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info); +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_HELPERS_H */ diff --git a/src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h b/src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h new file mode 100644 index 0000000000..a49836cfda --- /dev/null +++ 
b/src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_ICL_GEMM_KERNEL_CONFIG_H +#define ARM_COMPUTE_ICL_GEMM_KERNEL_CONFIG_H + +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/Types.h" +#include "src/core/common/Macros.h" + +#include +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Basic container for the OpenCL GEMM configuration functions */ +template +class CLGEMMConfigArray +{ +public: + /** Alias for F32 index */ + static constexpr size_t DT_F32 = 0; + /** Alias for F16 index */ + static constexpr size_t DT_F16 = 1; + /** Alias for Int8 index */ + static constexpr size_t DT_INT8 = 2; + + /** Constructor + * + * @param[in] func_f32 Function to call for GEMM F32 + * @param[in] func_f16 Function to call for GEMM F16 + * @param[in] func_int8 Function to call for GEMM Int8 (QASYMM8, QASYMM8_SIGNED, QSYMM8_PER_CHANNEL) + * + */ + CLGEMMConfigArray(T func_f32, T func_f16, T func_int8) + : _configs{ func_f32, func_f16, func_int8 } + { + } + + /** Method to return the GEMM configuration function based on data type + * + * @param[in] data_type Input data type + * + * @return the valid function otherwise it returns nullptr if the data type is not valid + */ + T get_function(DataType data_type) + { + switch(data_type) + { + case DataType::F32: + return _configs.at(DT_F32); + case DataType::F16: + return _configs.at(DT_F16); + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::QSYMM8_PER_CHANNEL: + return _configs.at(DT_INT8); + default: + return nullptr; + } + } + +private: + std::array _configs; +}; + +/** Basic interface for the GEMM kernel configuration */ +class IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] arch GPU target + */ + IClGemmKernelConfig(GPUTarget arch) + : _target(arch) + { + } + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(IClGemmKernelConfig); + /** Virtual destructor */ + virtual ~IClGemmKernelConfig() = default; + /** Given M, N, K and B, this method returns the @ref 
GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo to be used + * + * @param[in] m Number of rows LHS matrix + * @param[in] n Number of columns RHS matrix + * @param[in] k Number of columns LHS matrix or number of rows RHS matrix + * @param[in] b Batch size + * @param[in] data_type Data type + */ + virtual std::pair configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) = 0; + +protected: + GPUTarget _target; +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_ICL_GEMM_KERNEL_CONFIG_H */ diff --git a/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.cpp b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.cpp new file mode 100644 index 0000000000..9d11006703 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.cpp @@ -0,0 +1,246 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +ClGemmDefaultConfigNativeBifrost::ClGemmDefaultConfigNativeBifrost(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigNativeBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigNativeBifrost::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + CLGEMMConfigArray configs_G71(&ClGemmDefaultConfigNativeBifrost::configure_G71_f32, + &ClGemmDefaultConfigNativeBifrost::configure_G71_f32, // We use the F32 heuristic + &ClGemmDefaultConfigNativeBifrost::configure_G71_u8); + + CLGEMMConfigArray configs_G76(&ClGemmDefaultConfigNativeBifrost::configure_G76_f32, + &ClGemmDefaultConfigNativeBifrost::configure_G76_f32, // We use the F32 heuristic + &ClGemmDefaultConfigNativeBifrost::configure_G76_u8); + + CLGEMMConfigArray configs_G7x(&ClGemmDefaultConfigNativeBifrost::configure_default_f32, + &ClGemmDefaultConfigNativeBifrost::configure_default_f32, // We use the F32 heuristic + &ClGemmDefaultConfigNativeBifrost::configure_default_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + + switch(_target) + { + case GPUTarget::G76: + func = configs_G76.get_function(data_type); + break; + case GPUTarget::G71: + func = configs_G71.get_function(data_type); 
+ break; + default: + func = configs_G7x.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigNativeBifrost::configure_G71_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n < 2048) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false); + } + else if(n >= 2048 && n < 8192) + { + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 1, false, false, false, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 5, 4, 2, 1, 1, false, false, false, false); + } +} + +std::pair ClGemmDefaultConfigNativeBifrost::configure_G71_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(dot8_supported(CLKernelLibrary::get().get_device())) + { + if(m == 1) + { + if(n < 2048) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false); + } + else if(n >= 2048 && n < 16384) + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false); + } + } + else + { + if(m < 64) + { + return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false); + } + } + } + else + { + if(m == 1) + { + if(n < 8192) + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 2, 8, 16, 1, 1, false, false, false, false); + } + } +} + +std::pair 
ClGemmDefaultConfigNativeBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n > 4196) + { + return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 1, false, false, false, false); + } + else + { + if(k < 2048) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 1, false, false, false, false); + } + else if(k >= 2048 && k < 16384) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 1, false, false, false, false); + } + } + } + else + { + return configure_lhs_rhs_info(m, n, 2, 8, 2, 1, 1, false, false, false, false); + } +} + +std::pair ClGemmDefaultConfigNativeBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n < 2048) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false); + } + else if(n >= 2048 && n < 16384) + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false); + } + } + else + { + if(m < 64) + { + return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false); + } + } +} + +std::pair ClGemmDefaultConfigNativeBifrost::configure_default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 1, false, false, false, false); +} + +std::pair ClGemmDefaultConfigNativeBifrost::configure_default_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + return configure_lhs_rhs_info(m, n, 5, 2, 16, 
1, 1, false, false, false, false); +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.h b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.h new file mode 100644 index 0000000000..385b96e40e --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_BIFROST_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_BIFROST_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Bifrost based OpenCL GEMMNative configuration */ +class ClGemmDefaultConfigNativeBifrost final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigNativeBifrost(GPUTarget gpu); + + // Inherited overridden method + std::pair configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair configure_G71_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G71_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_default_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_BIFROST_H */ diff --git a/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.cpp b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.cpp new file mode 100644 index 0000000000..e3c129e3be --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +ClGemmDefaultConfigNativeMidgard::ClGemmDefaultConfigNativeMidgard(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigNativeMidgard::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigNativeMidgard::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + CLGEMMConfigArray configs_default(nullptr, + nullptr, + &ClGemmDefaultConfigNativeMidgard::default_q8); + + auto func = configs_default.get_function(data_type); + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigNativeMidgard::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + const unsigned int m0 = std::min(m, static_cast(4)); + const unsigned int n0 = std::min(n, static_cast(4)); + + return configure_lhs_rhs_info(m, n, m0, n0, 2, 1, 1, false, false, false, false); +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.h b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.h new file mode 100644 index 0000000000..0ff5471f7c --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_MIDGARD_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_MIDGARD_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Midgard based OpenCL GEMMNative configuration */ +class ClGemmDefaultConfigNativeMidgard final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigNativeMidgard(GPUTarget gpu); + + // Inherited overridden method + std::pair configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_MIDGARD_H */ diff --git a/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp new file mode 100644 index 0000000000..92767aca52 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +ClGemmDefaultConfigNativeValhall::ClGemmDefaultConfigNativeValhall(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigNativeValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigNativeValhall::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + CLGEMMConfigArray configs_default(&ClGemmDefaultConfigNativeValhall::configure_G77_f32, + &ClGemmDefaultConfigNativeValhall::configure_G77_f16, + &ClGemmDefaultConfigNativeValhall::configure_G77_u8); + + auto func = configs_default.get_function(data_type); + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigNativeValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n < 2048) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false); + } + else if(n >= 
2048 && n < 8192) + { + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 1, false, false, false, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 5, 4, 2, 1, 1, false, false, false, false); + } +} + +std::pair ClGemmDefaultConfigNativeValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n < 2048) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false); + } + else if(n >= 2048 && n < 8192) + { + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 1, false, false, false, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 8, 2, 1, 1, false, false, false, false); + } +} + +std::pair ClGemmDefaultConfigNativeValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(dot8_supported(CLKernelLibrary::get().get_device())) + { + if(m == 1) + { + if(n < 2048) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false); + } + else if(n >= 2048 && n < 16384) + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false); + } + } + else + { + if(m < 64) + { + return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false); + } + } + } + else + { + if(m == 1) + { + if(n < 8192) + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false); + } + } + else + 
{ + return configure_lhs_rhs_info(m, n, 2, 8, 16, 1, 1, false, false, false, false); + } + } +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.h b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.h new file mode 100644 index 0000000000..17e4c9d339 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_VALHALL_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_VALHALL_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Valhall based OpenCL GEMMNative configuration */ +class ClGemmDefaultConfigNativeValhall final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigNativeValhall(GPUTarget gpu); + + // Inherited overridden method + std::pair configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_NATIVE_VALHALL_H */ diff --git a/src/core/gpu/cl/kernels/gemm/native/ClGemmNativeKernelConfig.h b/src/core/gpu/cl/kernels/gemm/native/ClGemmNativeKernelConfig.h new file mode 100644 index 0000000000..ff6a0128af --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/native/ClGemmNativeKernelConfig.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_GEMM_NATIVE_KERNEL_CONFIGURATION_H +#define ARM_COMPUTE_CL_GEMM_NATIVE_KERNEL_CONFIGURATION_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" +#include "src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.h" +#include "src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.h" +#include "src/core/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** CLGEMMNative factory class */ +class ClGemmNativeKernelConfigurationFactory final +{ +public: + /** Static method to construct CLGEMMNative kernel object accordingly with the GPU target + * + * @param[in] gpu GPU target + * + * @return CLGEMMNative kernel configuration class + */ + static std::unique_ptr create(GPUTarget gpu) + { + switch(get_arch_from_target(gpu)) + { + case GPUTarget::MIDGARD: + return std::make_unique(gpu); + case GPUTarget::BIFROST: + return std::make_unique(gpu); + case GPUTarget::VALHALL: + return std::make_unique(gpu); + default: + ARM_COMPUTE_ERROR("Not supported GPU target"); + } + } +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /*ARM_COMPUTE_CL_GEMM_NATIVE_KERNEL_CONFIGURATION_H */ diff --git a/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp new file mode 100644 index 0000000000..b030913a87 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +using namespace arm_compute::misc::shape_calculator; + +ClGemmDefaultConfigReshapedBifrost::ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigReshapedBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + + CLGEMMConfigArray configs_G7x(&ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8); + + CLGEMMConfigArray configs_G52(&ClGemmDefaultConfigReshapedBifrost::configure_G52_f32, + &ClGemmDefaultConfigReshapedBifrost::configure_G52_f16, + &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8); + + CLGEMMConfigArray configs_G76(&ClGemmDefaultConfigReshapedBifrost::configure_G76_f32, + &ClGemmDefaultConfigReshapedBifrost::configure_G76_f16, + &ClGemmDefaultConfigReshapedBifrost::configure_G76_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + + switch(_target) + { + case GPUTarget::G76: + func = configs_G76.get_function(data_type); + break; + case GPUTarget::G52: + func = configs_G52.get_function(data_type); + break; + default: + func = configs_G7x.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not 
support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false); + } +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(dot8_supported(CLKernelLibrary::get().get_device())) + { + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2, true, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, true, false, false, true); + } + } + else + { + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2, true, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2, true, true, false, true); + } + } +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo 
rhs_info_img; + + if(workload <= 274.4000f) + { + if(r_nk <= 0.7461f) + { + if(r_mn <= 21.1667f) + { + return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(r_mk <= 17.3926f) + { + if(workload <= 542.4000f) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(r_nk <= 0.5463f) + { + if(workload <= 11767.6001f) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 
4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(workload <= 323.4000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2, true, true, true, false, false); + } +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + // Get lhs_info/rhs_info in case of OpenCL buffer + if(n <= 4) + { + 
std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true); + } + else + { + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true); + } + + // Get lhs_info/rhs_info in case of OpenCL image + // Condition on the GPU workload + if((m / 4) * (n / 4) >= 2560) + { + // Big workload + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true); + } + else + { + // Small workload + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true); + } + + const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32); + const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img); + const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32); + + // In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d + const bool use_cl_image2d = (n <= 4) ? 
false : true; + + if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) + { + return std::make_pair(lhs_info_img, rhs_info_img); + } + else + { + return std::make_pair(lhs_info_buf, rhs_info_buf); + } +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + + if(workload <= 1595.2000f) + { + if(r_mk <= 2.1044f) + { + if(workload <= 870.4000f) + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2, true, false, true, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false, false); + } +} + +std::pair ClGemmDefaultConfigReshapedBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, false, true, false, true); + } +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.h b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.h new file mode 100644 index 0000000000..52e6ce3f48 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_BIFROST_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_BIFROST_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Bifrost based OpenCL GEMMReshaped configuration */ +class ClGemmDefaultConfigReshapedBifrost final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu); + + // Inherited overridden method + std::pair configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_BIFROST_H */ diff --git a/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp new file mode 100644 index 0000000000..57e42c92b3 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp @@ -0,0 +1,538 @@ +/* + * Copyright (c) 2020-2021 Arm 
Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +ClGemmDefaultConfigReshapedValhall::ClGemmDefaultConfigReshapedValhall(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigReshapedValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigReshapedValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + + CLGEMMConfigArray configs_G77(&ClGemmDefaultConfigReshapedValhall::configure_G77_f32, + &ClGemmDefaultConfigReshapedValhall::configure_G77_f16, + &ClGemmDefaultConfigReshapedValhall::configure_G77_u8); + + CLGEMMConfigArray configs_G78(&ClGemmDefaultConfigReshapedValhall::configure_G78_f32, + &ClGemmDefaultConfigReshapedValhall::configure_G78_f16, + &ClGemmDefaultConfigReshapedValhall::configure_G77_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + + switch(_target) + { + case GPUTarget::G78: + func = configs_G78.get_function(data_type); + break; + case GPUTarget::G77: + default: + func = configs_G77.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigReshapedValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, 1, 0, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, 0, 1, 0, 1); + } +} + +std::pair 
ClGemmDefaultConfigReshapedValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0); + + if(r_mk <= 0.11824845522642136) + { + if(workload <= 880.0) + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0); + } + else + { + if(r_nk <= 0.42521367967128754) + { + if(workload <= 1726.4000244140625) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 0); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + else + { + if(workload <= 1241.6000366210938) + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0); + } + } + } + } + else + { + if(workload <= 11404.7998046875) + { + if(r_mk <= 1.0126488208770752) + { + if(r_mn <= 2.545312523841858) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0); + } + } + else + { + if(workload <= 2881.199951171875) + { + std::tie(lhs_info_img, 
rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, 0, 0, 1, 0, 1); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + } + else + { + if(r_nk <= 0.5765306055545807) + { + if(r_mn <= 6.010416746139526) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(workload <= 1288.0000f) + { + if(workload <= 505.6000f) + { + if(r_mn <= 0.4466f) + { + if(r_nk <= 0.2384f) + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, 
n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0); + } + } + else + { + if(r_mn <= 0.2250f) + { + if(r_mn <= 0.1599f) + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + } + else + { + if(r_mk <= 0.7609f) + { + if(r_mn <= 2.5453f) + { + if(workload <= 1089.6000f) + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 2, 4, 0, 0, 1, 0, 1); + } + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 4, 4, 0, 0, 1, 0, 1); + } + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1); + } + } + } + } + else + { + if(workload <= 5434.4001f) + { + if(workload <= 1603.2000f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + if(r_nk <= 0.6192f) + { + if(r_mn <= 16.1016f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + if(workload <= 2750.0000f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + if(r_mk <= 6.3151f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + } + } + } + else + { + if(r_mk <= 0.0387f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1); + } + else + { + if(r_mk <= 2.5859f) + { + if(r_mk <= 0.2734f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + } + } + } + } + else + { + if(r_mk <= 25.7500f) + { + if(r_mk <= 0.3615f) + { + if(r_mn <= 0.0913f) + { + if(r_mk <= 0.0683f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 
4, 4, 2, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1); + } + } + else + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + } + else + { + if(workload <= 11174.3999f) + { + if(r_mk <= 0.8047f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + if(workload <= 7185.5999f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1); + } + } + } + else + { + if(workload <= 17917.5000f) + { + if(r_mk <= 1.5078f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1); + } + } + else + { + if(workload <= 34449.6016f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1); + } + } + } + } + } + else + { + if(r_mk <= 331.1111f) + { + if(workload <= 53397.5996f) + { + if(r_mn <= 57.8063f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1); + } + } + else + { + if(r_nk <= 0.9211f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1); + } + } + } + else + { + if(workload <= 38070.4004f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + } + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float r_nk = static_cast(n) / static_cast(k); + const float workload = (static_cast(m) * static_cast(n) * 
static_cast(b)) / 20.0f; + + if(workload <= 801.6000f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1); + } + else + { + if(r_mn <= 0.1211f) + { + if(workload <= 3296.0000f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + if(r_nk <= 1.0625f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1); + } + } + } + else + { + if(workload <= 5068.8000f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1); + } + else + { + if(r_nk <= 0.2361f) + { + if(workload <= 12630.0000f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 1, 0, 0, 1, 0, 1); + } + } + else + { + if(workload <= 178790.3984f) + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1); + } + } + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, 0, 0, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, 0, 1, 0, 1); + } +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h new file mode 100644 index 0000000000..588cd64e0e --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_VALHALL_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_VALHALL_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Valhall based OpenCL GEMMReshaped configuration */ +class ClGemmDefaultConfigReshapedValhall final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigReshapedValhall(GPUTarget gpu); + + // Inherited overridden method + std::pair configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_VALHALL_H */ diff --git a/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmReshapedKernelConfig.h b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmReshapedKernelConfig.h new file mode 100644 index 0000000000..c990c89a91 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped/ClGemmReshapedKernelConfig.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_GEMM_RESHAPED_KERNEL_CONFIGURATION_H +#define ARM_COMPUTE_CL_GEMM_RESHAPED_KERNEL_CONFIGURATION_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" +#include "src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.h" +#include "src/core/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** CLGEMMReshaped factory class */ +class ClGemmReshapedKernelConfigurationFactory final +{ +public: + /** Static method to call the CLGEMMReshaped kernel configuration class accordingly with the GPU target + * + * @param[in] gpu GPU target + * + * @return CLGEMMReshaped kernel configuration class + */ + static std::unique_ptr create(GPUTarget gpu) + { + switch(get_arch_from_target(gpu)) + { + case GPUTarget::MIDGARD: + case GPUTarget::BIFROST: + return std::make_unique(gpu); + case GPUTarget::VALHALL: + return std::make_unique(gpu); + default: + ARM_COMPUTE_ERROR("Not supported GPU target"); + } + } +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_RESHAPED_KERNEL_CONFIGURATION_H */ diff --git a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp new file mode 100644 index 0000000000..7ed6b39f3e --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp @@ -0,0 +1,518 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +using namespace arm_compute::misc::shape_calculator; + +ClGemmDefaultConfigReshapedRhsOnlyBifrost::ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + CLGEMMConfigArray configs_G51(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8); + + CLGEMMConfigArray configs_G52(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + + CLGEMMConfigArray configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8); + + CLGEMMConfigArray configs_G7x(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + + 
ConfigurationFunctionExecutorPtr func = nullptr; + + switch(_target) + { + case GPUTarget::G76: + func = configs_G76.get_function(data_type); + break; + case GPUTarget::G51: + func = configs_G51.get_function(data_type); + break; + case GPUTarget::G52: + func = configs_G52.get_function(data_type); + break; + default: + func = configs_G7x.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n <= 2548) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + const bool is_workload_big = ((m * n * b) / 16) >= 2048; + + if(m == 1) + { + if(n >= 8192) + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + if(n <= 204) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false); + } + } + } + else + { + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(16)), static_cast(1)); + if(is_workload_big) + { 
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true); + } + else + { + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true); + } + } + + // Get lhs_info/rhs_info in case of OpenCL image + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(16)), static_cast(1)); + if(is_workload_big) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true); + } + + const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32); + const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img); + const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32); + + // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d + const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? 
false : true; + + if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) + { + return std::make_pair(lhs_info_img, rhs_info_img); + } + else + { + return std::make_pair(lhs_info_buf, rhs_info_buf); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_nk = static_cast(n) / static_cast(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if(m == 1) + { + if(r_nk <= 0.4664f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(workload <= 274.4000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int n0 = n < 1280 ? 
2 : 4; + const unsigned int h0 = std::max(n / n0, 1U); + return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n > 2048) + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if(m == 1) + { + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false); + + if(r_mk <= 0.0026f) + { + if(r_nk <= 0.4664f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + else + { + if(r_mk <= 0.0148f) + { + return 
configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + } + else + { + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false); + + if(workload <= 362.6000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); + } + else + { + if(r_mn <= 22.6067f) + { + if(workload <= 708.8000f) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false); + } + } + else + { + if(r_nk <= 0.0917f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + + if(m == 1) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(workload <= 7449.60f) + { + if(workload <= 691.60f) + { 
+ return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false); + } + else + { + if(workload <= 4155.20f) + { + return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false); + } + } + } + else + { + if(workload <= 16300.80f) + { + if(r_mn <= 44.56f) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int n0 = n < 1280 ? 
2 : 4; + const unsigned int h0 = std::max(n / n0, 1U); + return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(dot8_supported(CLKernelLibrary::get().get_device())) + { + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true); + } + } + else + { + const int h0 = std::max(std::min(static_cast(n / 2), static_cast(128)), static_cast(1)); + if(m == 1) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true); + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, 
false, true, false, true); + } +} + +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h new file mode 100644 index 0000000000..7b1a1fb04d --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_BIFROST_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_BIFROST_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Bifrost based OpenCL GEMMReshapedOnlyRHS configuration */ +class ClGemmDefaultConfigReshapedRhsOnlyBifrost final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu); + + // Inherited overridden method + std::pair configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_BIFROST_H */ diff --git 
a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp new file mode 100644 index 0000000000..4c6e633896 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp @@ -0,0 +1,570 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +using namespace arm_compute::misc::shape_calculator; + +ClGemmDefaultConfigReshapedRhsOnlyValhall::ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + CLGEMMConfigArray configs_G77(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + + CLGEMMConfigArray configs_G78(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + + switch(_target) + { + case GPUTarget::G78: + func = configs_G78.get_function(data_type); + break; + case GPUTarget::G77: + default: + func = configs_G77.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int 
k, unsigned int b) +{ + if(m == 1) + { + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + + if(r_mk <= 0.0064484127797186375) + { + if(r_mn <= 0.0028273810748942196) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + const unsigned int h0 = std::max(n / 4, 1U); + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 1, 0, 0, 0); + } + } + else + { + if(r_mk <= 0.020312500186264515) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, 0, 1, 0, 1, 0); + } + } + } + else + { + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + + if(workload <= 1999.2000122070312) + { + if(workload <= 747.1999816894531) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(r_mn <= 0.03348214365541935) + { + 
if(r_mk <= 0.028125000186264515) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, 0, 1, 0, 0, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + if(n <= 836.0) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, 0, 1, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, 0, 1, 0, 1, 0); + } + } + else if(m < 128) + { + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); + if(k >= 512) + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0); + } + } + else + { + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); + if(n >= 64) + { + return 
configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, 0, 1, 0, 0); + } + else + { + if(k >= 512) + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0); + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1); + } + else + { + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); + if(m >= 28) + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 1); + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(m == 1) + { + if(workload <= 278.7000f) + { + if(workload <= 7.5000f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + if(r_mn <= 0.0031f) + { + if(workload <= 256.6000f) + { + if(workload <= 16.7500f) + { + if(r_nk <= 1.6671f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + else + { + if(r_mk <= 0.0027f) + { + if(r_mk <= 0.0014f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 
0, 1, 0); + } + else + { + if(workload <= 8.9500f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + } + else + { + if(workload <= 14.1500f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + if(r_mk <= 0.0041f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + } + } + } + } + } + else + { + if(workload <= 363.7000f) + { + if(r_mk <= 0.0031f) + { + return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 32, 0, 1, 0, 1, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0); + } + } + } + else + { + if(workload <= 1384.8000f) + { + if(workload <= 704.0000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1); + } + } + else + { + if(workload <= 16761.6006f) + { + if(r_mn <= 187.1250f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1); + } + } + else + { + if(r_mk <= 432.4630f) + { + return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 16, 0, 1, 0, 1, 1); + } + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(m == 1) + { + if(r_mn <= 0.0038f) + { + 
if(workload <= 353.9000f) + { + if(workload <= 278.7000f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + else + { + if(r_mk <= 0.0004f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + else + { + if(r_mk <= 0.0030f) + { + return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 8, 0, 1, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + } + } + } + else + { + if(r_nk <= 1.9384f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 8, 0, 1, 1, 0, 1); + } + } + } + else + { + if(r_nk <= 1.0368f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, 0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + } + } + else + { + if(workload <= 1422.4000f) + { + if(workload <= 704.0000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0); + } + else + { + if(workload <= 1197.6000f) + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); + } + else + { + if(workload <= 1241.6000f) + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); + } + } + } + } + else + { + if(workload <= 2769.6000f) + { + if(workload <= 1846.4000f) + { + if(r_mn <= 2.4927f) + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + } + else + { + if(r_mn <= 0.6261f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + else + { + if(r_mk <= 3.4453f) + { + if(r_mn <= 1.4135f) + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 2, 
8, 8, 1, 16, 0, 1, 1, 0, 0); + } + } + } + } + else + { + if(r_nk <= 0.0302f) + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); + } + else + { + if(r_mk <= 181.3750f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + else + { + if(workload <= 28035.2002f) + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + else + { + if(r_mk <= 808.6667f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + } + } + } + } + } + } +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h new file mode 100644 index 0000000000..6a11ddb748 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_VALHALL_H +#define ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_VALHALL_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** Valhall based OpenCL GEMMReshapedOnlyRHS configuration */ +class ClGemmDefaultConfigReshapedRhsOnlyValhall final : public IClGemmKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu); + + // Inherited overridden method + std::pair configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + +private: + std::pair configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_DEFAULT_CONFIG_RESHAPED_RHS_ONLY_VALHALL_H */ diff --git a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultReshapedRhsOnlyBifrost.cpp b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultReshapedRhsOnlyBifrost.cpp new file mode 100644 index 0000000000..7ed6b39f3e --- /dev/null +++ 
b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultReshapedRhsOnlyBifrost.cpp @@ -0,0 +1,518 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +using namespace arm_compute::misc::shape_calculator; + +ClGemmDefaultConfigReshapedRhsOnlyBifrost::ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + CLGEMMConfigArray configs_G51(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8); + + CLGEMMConfigArray configs_G52(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + + CLGEMMConfigArray configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8); + + CLGEMMConfigArray configs_G7x(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + + 
ConfigurationFunctionExecutorPtr func = nullptr; + + switch(_target) + { + case GPUTarget::G76: + func = configs_G76.get_function(data_type); + break; + case GPUTarget::G51: + func = configs_G51.get_function(data_type); + break; + case GPUTarget::G52: + func = configs_G52.get_function(data_type); + break; + default: + func = configs_G7x.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n <= 2548) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + const bool is_workload_big = ((m * n * b) / 16) >= 2048; + + if(m == 1) + { + if(n >= 8192) + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + if(n <= 204) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false); + } + } + } + else + { + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(16)), static_cast(1)); + if(is_workload_big) + { 
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true); + } + else + { + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true); + } + } + + // Get lhs_info/rhs_info in case of OpenCL image + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(16)), static_cast(1)); + if(is_workload_big) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true); + } + + const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32); + const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img); + const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32); + + // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d + const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? 
false : true; + + if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) + { + return std::make_pair(lhs_info_img, rhs_info_img); + } + else + { + return std::make_pair(lhs_info_buf, rhs_info_buf); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_nk = static_cast(n) / static_cast(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if(m == 1) + { + if(r_nk <= 0.4664f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(workload <= 274.4000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int n0 = n < 1280 ? 
2 : 4; + const unsigned int h0 = std::max(n / n0, 1U); + return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n > 2048) + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if(m == 1) + { + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false); + + if(r_mk <= 0.0026f) + { + if(r_nk <= 0.4664f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + else + { + if(r_mk <= 0.0148f) + { + return 
configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + } + else + { + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false); + + if(workload <= 362.6000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); + } + else + { + if(r_mn <= 22.6067f) + { + if(workload <= 708.8000f) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false); + } + } + else + { + if(r_nk <= 0.0917f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + + if(m == 1) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); + } + else + { + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(workload <= 7449.60f) + { + if(workload <= 691.60f) + { 
+ return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false); + } + else + { + if(workload <= 4155.20f) + { + return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false); + } + } + } + else + { + if(workload <= 16300.80f) + { + if(r_mn <= 44.56f) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int n0 = n < 1280 ? 
2 : 4; + const unsigned int h0 = std::max(n / n0, 1U); + return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(dot8_supported(CLKernelLibrary::get().get_device())) + { + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 4, 1U); + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true); + } + } + else + { + const int h0 = std::max(std::min(static_cast(n / 2), static_cast(128)), static_cast(1)); + if(m == 1) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true); + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, 
false, true, false, true); + } +} + +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultReshapedRhsOnlyValhall.cpp b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultReshapedRhsOnlyValhall.cpp new file mode 100644 index 0000000000..4c6e633896 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultReshapedRhsOnlyValhall.cpp @@ -0,0 +1,570 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +using namespace arm_compute::misc::shape_calculator; + +ClGemmDefaultConfigReshapedRhsOnlyValhall::ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu) + : IClGemmKernelConfig(gpu) +{ +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + CLGEMMConfigArray configs_G77(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + + CLGEMMConfigArray configs_G78(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + + switch(_target) + { + case GPUTarget::G78: + func = configs_G78.get_function(data_type); + break; + case GPUTarget::G77: + default: + func = configs_G77.get_function(data_type); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); + return (this->*func)(m, n, k, b); +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int 
k, unsigned int b) +{ + if(m == 1) + { + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + + if(r_mk <= 0.0064484127797186375) + { + if(r_mn <= 0.0028273810748942196) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + const unsigned int h0 = std::max(n / 4, 1U); + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 1, 0, 0, 0); + } + } + else + { + if(r_mk <= 0.020312500186264515) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, 0, 1, 0, 1, 0); + } + } + } + else + { + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + + if(workload <= 1999.2000122070312) + { + if(workload <= 747.1999816894531) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(r_mn <= 0.03348214365541935) + { + 
if(r_mk <= 0.028125000186264515) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, 0, 1, 0, 0, 1); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + if(n <= 836.0) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, 0, 1, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, 0, 1, 0, 1, 0); + } + } + else if(m < 128) + { + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); + if(k >= 512) + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0); + } + } + else + { + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); + if(n >= 64) + { + return 
configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, 0, 1, 0, 0); + } + else + { + if(k >= 512) + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0); + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, 1U); + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1); + } + else + { + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); + if(m >= 28) + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 1); + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(m == 1) + { + if(workload <= 278.7000f) + { + if(workload <= 7.5000f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + if(r_mn <= 0.0031f) + { + if(workload <= 256.6000f) + { + if(workload <= 16.7500f) + { + if(r_nk <= 1.6671f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + else + { + if(r_mk <= 0.0027f) + { + if(r_mk <= 0.0014f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 
0, 1, 0); + } + else + { + if(workload <= 8.9500f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + } + } + else + { + if(workload <= 14.1500f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + else + { + if(r_mk <= 0.0041f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); + } + } + } + } + } + } + else + { + if(workload <= 363.7000f) + { + if(r_mk <= 0.0031f) + { + return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 32, 0, 1, 0, 1, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0); + } + } + } + else + { + if(workload <= 1384.8000f) + { + if(workload <= 704.0000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1); + } + } + else + { + if(workload <= 16761.6006f) + { + if(r_mn <= 187.1250f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1); + } + } + else + { + if(r_mk <= 432.4630f) + { + return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 16, 0, 1, 0, 1, 1); + } + } + } + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(m == 1) + { + if(r_mn <= 0.0038f) + { + 
if(workload <= 353.9000f) + { + if(workload <= 278.7000f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + else + { + if(r_mk <= 0.0004f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + else + { + if(r_mk <= 0.0030f) + { + return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 8, 0, 1, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + } + } + } + else + { + if(r_nk <= 1.9384f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 8, 0, 1, 1, 0, 1); + } + } + } + else + { + if(r_nk <= 1.0368f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, 0, 0, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); + } + } + } + else + { + if(workload <= 1422.4000f) + { + if(workload <= 704.0000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0); + } + else + { + if(workload <= 1197.6000f) + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); + } + else + { + if(workload <= 1241.6000f) + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); + } + } + } + } + else + { + if(workload <= 2769.6000f) + { + if(workload <= 1846.4000f) + { + if(r_mn <= 2.4927f) + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + } + else + { + if(r_mn <= 0.6261f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + else + { + if(r_mk <= 3.4453f) + { + if(r_mn <= 1.4135f) + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + } + else + { + return configure_lhs_rhs_info(m, n, 2, 
8, 8, 1, 16, 0, 1, 1, 0, 0); + } + } + } + } + else + { + if(r_nk <= 0.0302f) + { + return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); + } + else + { + if(r_mk <= 181.3750f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + else + { + if(workload <= 28035.2002f) + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + else + { + if(r_mk <= 808.6667f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + } + } + } + } + } + } + } +} +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h new file mode 100644 index 0000000000..8fd71276a0 --- /dev/null +++ b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_GEMM_RESHAPED_ONLY_RHS_KERNEL_CONFIGURATION_H +#define ARM_COMPUTE_CL_GEMM_RESHAPED_ONLY_RHS_KERNEL_CONFIGURATION_H + +#include "src/core/gpu/cl/kernels/gemm/IClGemmKernelConfig.h" +#include "src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h" +#include "src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace gemm +{ +/** CLGEMMReshapedOnlyRHS factory class */ +class ClGemmReshapedOnlyRhsKernelConfigurationFactory final +{ +public: + /** Static method to call the CLGEMMReshapedOnlyRHS kernel configuration class accordingly with the GPU target + * + * @param[in] gpu GPU target + * + * @return CLGEMMReshapedOnlyRHS kernel configuration class + */ + static std::unique_ptr create(GPUTarget gpu) + { + switch(get_arch_from_target(gpu)) + { + case GPUTarget::MIDGARD: + case GPUTarget::BIFROST: + return std::make_unique(gpu); + case GPUTarget::VALHALL: + return std::make_unique(gpu); + default: + ARM_COMPUTE_ERROR("Not supported GPU target"); + } + } +}; +} // namespace gemm +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_RESHAPED_ONLY_RHS_KERNEL_CONFIGURATION_H */ -- cgit v1.2.1