aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
authorGian Marco <gianmarco.iodice@arm.com>2018-01-12 10:21:40 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:44:21 +0000
commit36a0a4608bf413fc1fd65eb335bfb736ef602149 (patch)
tree2ff0e35dc9e16fedd601b1f24bdc13d25d075b90 /src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
parent46edf63bd630f5e3f3eb31b7d4602caa317da075 (diff)
downloadComputeLibrary-36a0a4608bf413fc1fd65eb335bfb736ef602149.tar.gz
COMPMID-748 - Integrating optimized SGEMM for bifrost
This patch introduces a new GEMM capable to improve the mac utilisation of 10% compared to the GEMM without reshape. However this implementation is not faster in all cases as we need to take into account the time for reshaping the matrices. For this reason an heuristic solution to select the optimal GEMM to use has been added to the function. More information about the heuristic implementation can be found at COMPMID-852. With this new patch, GoogleNet, MobileNet, VGG16 and SqueezeNet can improved the performance of 1.5x. More information about the performance uplift can be found here: https://confluence.arm.com/display/MLENG/GEMM+FP32+performance%3A+ACL+18.02 Change-Id: I024563c06b9aed02a211a974e452bae5c233b04c Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/117140 Reviewed-by: Pablo Tello <pablo.tello@arm.com> Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp75
1 files changed, 67 insertions, 8 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 19f38bf5a5..e23feb269a 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,24 +36,68 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include <set>
#include <string>
using namespace arm_compute;
+using namespace arm_compute::misc::shape_calculator;
namespace
{
using ElementsProcessed = Steps;
-inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed)
+inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1);
+
if(!is_interleaved_transposed)
{
ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
+
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
+ }
+ }
+ else
+ {
+ const int m = reshape_info.m();
+ const int n = reshape_info.n();
+ const int k = reshape_info.k();
+ const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
+ const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
+
+ TensorShape tensor_shape0{ input0->tensor_shape() };
+ tensor_shape0.set(0, k);
+ tensor_shape0.set(1, m);
+
+ TensorShape tensor_shape1{ input1->tensor_shape() };
+ tensor_shape1.set(0, n);
+ tensor_shape1.set(1, k);
+
+ const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
+ const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
+
+ const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
+ const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
+ }
}
return Status{};
@@ -122,12 +166,19 @@ CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
{
}
-void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed)
+void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
+ // Output tensor auto inizialitation if not yet initialized
+ TensorShape tensor_shape{ input0->info()->tensor_shape() };
+ tensor_shape.set(0, is_interleaved_transposed ? reshape_info.n() : input1->info()->dimension(0));
+ tensor_shape.set(1, is_interleaved_transposed ? reshape_info.m() : input0->info()->dimension(1));
+
+ auto_init_if_empty(*output->info(), input0->info()->clone()->set_tensor_shape(tensor_shape));
+
// Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
_input0 = input0;
_input1 = input1;
@@ -176,7 +227,13 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
std::string kernel_name;
if(is_interleaved_transposed)
{
+ const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
+ const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
+
build_opts.add_option("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0)));
+ build_opts.add_option("-DMULT_TRANSPOSE1XW_WIDTH=" + support::cpp11::to_string(mult_transpose1xW_width));
+ build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
+
if(data_type == DataType::F32)
{
kernel_name = "gemm_mm_interleaved_transposed_f32_" + string_from_target(arch_target);
@@ -230,11 +287,13 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
_config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1)));
}
-Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed, GPUTarget gpu_target)
+Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed,
+ const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target)
{
+ // Note: num_elements_processed will be set in validate_and_configure_window()
ElementsProcessed num_elements_processed{};
ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
input1->clone().get(),
output->clone().get(),