From eb65f6da695ac0d3e495817145cceb1c4de4f048 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Wed, 15 Apr 2020 11:42:15 +0100 Subject: COMPMID-3304: Update OpenCL GEMM heuristic for Int8 Change-Id: I6b7ff678d8d0437a1639db2ff602ea1cdb155464 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3056 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- .../CLGEMMNativeKernelConfigurationBifrost.cpp | 45 ++++++++++--- .../CLGEMMNativeKernelConfigurationMidgard.cpp | 75 ++++++++++++++++++++++ .../CLGEMMNativeKernelConfigurationValhall.cpp | 7 +- .../CLGEMMReshapedKernelConfigurationBifrost.cpp | 10 ++- .../CLGEMMReshapedKernelConfigurationValhall.cpp | 5 +- ...MMReshapedOnlyRHSKernelConfigurationBifrost.cpp | 20 ++++-- ...MMReshapedOnlyRHSKernelConfigurationValhall.cpp | 15 ++++- 7 files changed, 155 insertions(+), 22 deletions(-) create mode 100644 src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.cpp (limited to 'src/core/CL/gemm') diff --git a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp index c4a9ccd703..c6b51c698a 100644 --- a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp @@ -42,9 +42,6 @@ CLGEMMNativeKernelConfigurationBifrost::CLGEMMNativeKernelConfigurationBifrost(G std::pair CLGEMMNativeKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { - ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8); - ARM_COMPUTE_UNUSED(data_type); - using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMNativeKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); @@ -52,31 +49,61 @@ std::pair CLGEMMNativeKernelConfigurationB static std::map gemm_configs_G71 = { { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32 }, - { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 } + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 } }; // Configurations for Mali-G76 static std::map gemm_configs_G76 = { { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32 }, - { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 } + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 } }; // Default configurations static std::map gemm_configs_default = { { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_default_f32 }, - { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 } + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 } }; switch(_target) { case GPUTarget::G71: - return (this->*gemm_configs_G71[data_type])(m, n, k, b); + if(gemm_configs_G71.find(data_type) != gemm_configs_G71.end()) + { + return (this->*gemm_configs_G71[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } case GPUTarget::G76: - return (this->*gemm_configs_G76[data_type])(m, n, k, b); + if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end()) + { + return (this->*gemm_configs_G76[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } default: - return (this->*gemm_configs_default[data_type])(m, n, k, b); + if(gemm_configs_default.find(data_type) != gemm_configs_default.end()) + { + return (this->*gemm_configs_default[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } } } diff --git a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.cpp b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.cpp new file mode 100644 index 0000000000..86c056ffc2 --- /dev/null +++ b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2020 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h" +#include "arm_compute/core/GPUTarget.h" + +#include +#include + +namespace arm_compute +{ +namespace cl_gemm +{ +CLGEMMNativeKernelConfigurationMidgard::CLGEMMNativeKernelConfigurationMidgard(GPUTarget gpu) + : ICLGEMMKernelConfiguration(gpu) +{ +} + +std::pair CLGEMMNativeKernelConfigurationMidgard::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMNativeKernelConfigurationMidgard::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + // Configurations for Midgard architectures + static std::map default_configs = + { + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationMidgard::default_q8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationMidgard::default_q8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationMidgard::default_q8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationMidgard::default_q8 } + }; + + if(default_configs.find(data_type) != default_configs.end()) + { + return (this->*default_configs[data_type])(m, n, k, b); + } + ARM_COMPUTE_ERROR("Not supported data type"); +} + +std::pair CLGEMMNativeKernelConfigurationMidgard::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + const unsigned int m0 = std::min(m, static_cast(4)); + const unsigned int n0 = std::min(n, static_cast(4)); + + return configure_lhs_rhs_info(m, n, m0, n0, 2, 1, 1, false, false, false, false); +} +} // namespace cl_gemm +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationValhall.cpp b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationValhall.cpp index 7cf0f0e1a8..c25cdac81a 100644 --- a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationValhall.cpp @@ -45,12 +45,15 @@ std::pair CLGEMMNativeKernelConfigurationV using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMNativeKernelConfigurationValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - // Configurations for Mali-G71 + // Configurations for Mali-G77 static std::map gemm_configs_G77 = { { DataType::F32, &CLGEMMNativeKernelConfigurationValhall::configure_G77_f32 }, { DataType::F16, &CLGEMMNativeKernelConfigurationValhall::configure_G77_f16 }, - { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 } + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 } }; switch(_target) diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp index 144c23a798..990cc72eb0 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp @@ -49,7 +49,10 @@ std::pair CLGEMMReshapedKernelConfiguratio { { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 }, { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 } + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 } }; // Configurations for Mali-G7x @@ -57,7 +60,10 @@ std::pair CLGEMMReshapedKernelConfiguratio { { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 }, { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 } + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 } }; switch(_target) diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp index 20fa3d65bf..b96dc96e87 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp @@ -49,7 +49,10 @@ std::pair CLGEMMReshapedKernelConfiguratio { { DataType::F32, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_f32 }, { DataType::F16, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 } + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 } }; switch(_target) diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp index 8e798116bf..8826cca11b 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp @@ -50,7 +50,10 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32 }, { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 } + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 } }; // Configurations for Mali-G76 @@ -58,7 +61,10 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32 }, { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 } + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 } }; // Configurations for Mali-G7x @@ -66,7 +72,10 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32 }, { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 } + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 } }; switch(_target) @@ -235,15 +244,14 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi } else { + const int h0 = std::max(std::min(static_cast(n / 2), static_cast(128)), static_cast(1)); if(m == 1) { - const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true); } else { - const unsigned int h0 = std::max(n / 4, 1U); - return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, h0, false, true, false, true); + return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true); } } } diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp index 951447e1a0..783d0fe91b 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp @@ -50,7 +50,10 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f32 }, { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 } + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 } }; switch(_target) @@ -135,7 +138,15 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi } else { - return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 4, false, true, false, true); + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); + if(m >= 28) + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, true); + } } } } // namespace cl_gemm -- cgit v1.2.1