From 926afe1c8ad6ba6a7bada62a4027fcb79d727104 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 19 Mar 2019 11:44:13 +0000 Subject: COMPMID-2097: Implement a heuristic to dispatch CLGEMMReshapedOnlyRHS kernel from CLGEMM Change-Id: I4170a80647b02501aa669e2c0347ddc39888ee76 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/928 Reviewed-by: Giuseppe Rossini Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/core/CL/gemm/CLGEMMHelpers.cpp | 55 +++++++ .../CLGEMMReshapedKernelConfigurationBifrost.cpp | 147 ++++++++++++++++++ ...MMReshapedOnlyRHSKernelConfigurationBifrost.cpp | 171 +++++++++++++++++++++ .../CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 29 +++- 4 files changed, 394 insertions(+), 8 deletions(-) create mode 100644 src/core/CL/gemm/CLGEMMHelpers.cpp create mode 100644 src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp create mode 100644 src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp (limited to 'src/core') diff --git a/src/core/CL/gemm/CLGEMMHelpers.cpp b/src/core/CL/gemm/CLGEMMHelpers.cpp new file mode 100644 index 0000000000..4597d79d43 --- /dev/null +++ b/src/core/CL/gemm/CLGEMMHelpers.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h" + +#include + +namespace arm_compute +{ +namespace cl_gemm +{ +std::pair configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, + bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose) +{ + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + + // Configure GEMMLHSMatrixInfo + lhs_info.m0 = m0; + lhs_info.k0 = k0; + lhs_info.v0 = ((m / (lhs_info.m0 * v0)) == 0) ? 1 : v0; + lhs_info.interleave = lhs_interleave; + lhs_info.transpose = lhs_transpose; + + // Configure GEMMRHSMatrixInfo + rhs_info.n0 = n0; + rhs_info.k0 = lhs_info.k0; + rhs_info.h0 = ((n / (rhs_info.n0 * h0)) == 0) ? 1 : h0; + rhs_info.interleave = rhs_interleave; + rhs_info.transpose = rhs_transpose; + + return std::make_pair(lhs_info, rhs_info); +} +} // namespace cl_gemm +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp new file mode 100644 index 0000000000..b791c1cda5 --- /dev/null +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h" +#include "arm_compute/core/GPUTarget.h" + +#include +#include + +namespace arm_compute +{ +namespace cl_gemm +{ +CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifrost(GPUTarget arch) + : ICLGEMMKernelConfiguration(arch) +{ +} + +std::pair CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8); + ARM_COMPUTE_UNUSED(data_type); + + using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + + // Configurations for Mali-G76 + static std::map gemm_configs_G76 = + { + { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 }, + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 } + }; + + // Configurations for Mali-G7x + static std::map gemm_configs_G7x = + { + { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 }, + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 } + }; + + switch(_target) + { + case GPUTarget::G76: + return (this->*gemm_configs_G76[data_type])(m, n, k, b); + default: + return (this->*gemm_configs_G7x[data_type])(m, n, k, b); + } +} + +std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, false, true, false, true); + } +} + +std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(dot8_supported(CLKernelLibrary::get().get_device())) + { + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2, true, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, true, false, false, true); + } + } + else + { + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2, true, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2, true, true, false, true); + } + } +} + +std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true); + } +} + +std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, false, true, false, true); + } +} +} // namespace cl_gemm +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp new file mode 100644 index 0000000000..f696f0b253 --- /dev/null +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h" +#include "arm_compute/core/GPUTarget.h" + +#include +#include + +namespace arm_compute +{ +namespace cl_gemm +{ +CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::CLGEMMReshapedOnlyRHSKernelConfigurationBifrost(GPUTarget arch) + : ICLGEMMKernelConfiguration(arch) +{ +} + +std::pair CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8); + ARM_COMPUTE_UNUSED(data_type); + + using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + // Configurations for Mali-G76 + static std::map gemm_configs_G76 = + { + { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32 }, + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 } + }; + + // Configurations for Mali-G7x + static std::map gemm_configs_G7x = + { + { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32 }, + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 } + }; + + switch(_target) + { + case GPUTarget::G76: + return (this->*gemm_configs_G76[data_type])(m, n, k, b); + default: + return (this->*gemm_configs_G7x[data_type])(m, n, k, b); + } +} + +std::pair CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + if(n > 2048) + { + const unsigned int h0 = std::max(n / 4, static_cast(1)); + return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 2, static_cast(1)); + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true); + } +} + +std::pair CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, static_cast(1)); + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); + } +} + +std::pair CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(dot8_supported(CLKernelLibrary::get().get_device())) + { + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, static_cast(1)); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 4, static_cast(1)); + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true); + } + } + else + { + if(m == 1) + { + if(n > 2048) + { + const unsigned int h0 = std::max(n / 4, static_cast(1)); + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true); + } + else + { + const unsigned int h0 = std::max(n / 2, static_cast(1)); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + } + else + { + const unsigned int h0 = std::max(n / 4, static_cast(1)); + return configure_lhs_rhs_info(m, n, 4, 1, 16, 1, h0, false, true, false, true); + } + } +} + +std::pair CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(m == 1) + { + const unsigned int h0 = std::max(n / 2, static_cast(1)); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true); + } +} +} // namespace cl_gemm +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index af06fecd00..24372657f5 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -68,20 +68,23 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const int n = gemm_info.n(); const int k = gemm_info.k(); - TensorShape tensor_shape0{ input0->tensor_shape() }; - tensor_shape0.set(0, k); - tensor_shape0.set(1, m); - TensorShape tensor_shape1{ input1->tensor_shape() }; tensor_shape1.set(0, n); tensor_shape1.set(1, k); - const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0); const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1); const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info0); + ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != static_cast(k)); + if(gemm_info.reinterpret_input_as_3d()) + { + ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) * input0->dimension(2) != static_cast(m)); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != static_cast(m)); + } ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1); if(output->total_size() != 0) @@ -99,6 +102,7 @@ std::pair validate_and_configure_window(ITensorInfo *input0, ITe { unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0]; unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1]; + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); bool reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d() != 0); Window win{}; @@ -107,6 +111,10 @@ std::pair validate_and_configure_window(ITensorInfo *input0, ITe // In case both input and output have to be reinterpreted as 3D tensors, // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. + if(reinterpret_input_as_3d == reinterpret_output_as_3d) + { + reinterpret_output_as_3d = false; + } // Output tensor auto initialization if not yet initialized auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, gemm_info))); @@ -147,7 +155,7 @@ std::pair validate_and_configure_window(ITensorInfo *input0, ITe window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor - output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape())); + output_access.set_valid_region(win_out, ValidRegion(Coordinates(), output->tensor_shape())); // Collapse along the Z direction // This collapse needs to be here in order to tune the Z dimension of LWS @@ -181,6 +189,11 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input // In case both input and output have to be reinterpreted as 3D tensors, // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. + if(_reinterpret_input_as_3d == _reinterpret_output_as_3d) + { + _reinterpret_input_as_3d = false; + _reinterpret_output_as_3d = false; + } // Check if we need to slide the matrix B const unsigned int num_dimensions_input0 = _input0->info()->num_dimensions(); @@ -204,7 +217,7 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2))); build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE"); build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS"); - build_opts.add_option("-DM=" + support::cpp11::to_string(gemm_info.m())); + build_opts.add_option("-DM=" + support::cpp11::to_string(input0->info()->dimension(1))); build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n())); build_opts.add_option("-DK=" + support::cpp11::to_string(gemm_info.k())); build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0)); -- cgit v1.2.1