Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
 src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 5e02dda9e3..b549638343 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -47,12 +47,14 @@ namespace
{
using ElementsProcessed = Steps;
-inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
+ bool fp_mixed_precision)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((fp_mixed_precision && (input0->data_type() != DataType::F16)), "Mixed precision floating point is supported only for F16 data");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
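The hunk above threads the new fp_mixed_precision flag through the shared argument checks and rejects it outright unless matrix A is F16. A minimal standalone sketch of the same guard, assuming the library's Status/ErrorCode types (the real code expresses it through the ARM_COMPUTE_RETURN_ERROR_ON_MSG macro):

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Types.h"

// Hypothetical standalone version of the check added above.
arm_compute::Status check_fp_mixed_precision(arm_compute::DataType dt, bool fp_mixed_precision)
{
    using namespace arm_compute;
    if(fp_mixed_precision && dt != DataType::F16)
    {
        return Status(ErrorCode::RUNTIME_ERROR, "Mixed precision floating point is supported only for F16 data");
    }
    return Status{}; // defaults to ErrorCode::OK
}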
@@ -216,12 +218,13 @@ CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
{
}
-void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
+ bool fp_mixed_precision)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
// Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, fp_mixed_precision));
_input0 = input0;
_input1 = input1;
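configure() turns the shared checks into a hard error through ARM_COMPUTE_ERROR_THROW_ON, while the static validate() further down propagates them as a Status. A simplified sketch of that split with the new flag forwarded to the common helper (names abbreviated, not the real class layout):

// Both entry points forward fp_mixed_precision to the same helper.
void KernelSketch::configure(/* tensors, alpha, reshape info, */ bool fp_mixed_precision)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(/* ... */, fp_mixed_precision)); // fatal on failure
    // ...create the OpenCL kernel and set its window...
}

Status KernelSketch::validate(/* tensors, alpha, reshape info, */ bool fp_mixed_precision)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(/* ... */, fp_mixed_precision)); // early-returns the error
    return Status{};
}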
@@ -316,6 +319,11 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// The work-group size equal to the Bifrost quad size has been proved to be optimal for these kernels
// via exhaustive autotuning over a range of representative layer configurations.
set_lws_hint(cl::NDRange(4));
+ if(fp_mixed_precision && data_type == DataType::F16)
+ {
+ // Currently, a wider accumulator is only supported for FP16 kernels.
+ kernel_name += "_acc32";
+ }
}
else // (MIDGARD and F32) or (F16)
{
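The "_acc32" suffix selects an F16 kernel variant that keeps its accumulator in 32-bit float, so partial sums are not rounded back to half precision on every addition. A host-side C++ illustration of why that matters for long dot products; to_f16() is a hypothetical helper that emulates binary16 rounding (the real kernels do this in OpenCL C):

#include <cmath>
#include <cstdio>

// Rough emulation of rounding a value to IEEE binary16 precision
// (11 significant bits; denormals and overflow ignored for brevity).
static float to_f16(float v)
{
    int exp;
    float m = std::frexp(v, &exp);
    m = std::round(m * 2048.0f) / 2048.0f;
    return std::ldexp(m, exp);
}

int main()
{
    const int K = 4096;          // dot-product length
    float acc16 = 0.0f;          // emulated F16 accumulator
    float acc32 = 0.0f;          // F32 accumulator, as in the _acc32 variants
    const float prod = to_f16(0.01f) * to_f16(0.01f);
    for(int k = 0; k < K; ++k)
    {
        acc16 = to_f16(acc16 + prod); // rounds after every addition; stagnates early
        acc32 += prod;                // keeps the low-order bits
    }
    std::printf("f16 accumulate: %f, f32 accumulate: %f\n", acc16, acc32);
    return 0;
}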
@@ -331,6 +339,7 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Set config_id for enabling LWS tuning
_config_id = "gemm_";
_config_id += (is_interleaved_transposed ? "reshaped_" : "");
+ _config_id += (fp_mixed_precision ? "fp_mixed_" : "");
_config_id += (_reinterpret_input_as_3d ? "3di_" : "");
_config_id += (_reinterpret_output_as_3d ? "3do_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
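_config_id is the key under which the local work-group size tuner caches results, so a mixed-precision run must not reuse values tuned for a plain F16 run. With the new component, the id is composed roughly like this (illustrative; the output shape components appended later in configure() are elided):

// Illustrative composition of the tuner key after this change:
std::string config_id = "gemm_";
config_id += "reshaped_";  // present when is_interleaved_transposed
config_id += "fp_mixed_";  // new: present when fp_mixed_precision
config_id += "f16";        // lower_string(string_from_data_type(...))
// -> e.g. "gemm_reshaped_fp_mixed_f16_..." plus shape components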
@@ -347,12 +356,12 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
}
Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed,
- const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target)
+ const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
{
// Note: num_elements_processed will be set in validate_and_configure_window()
ElementsProcessed num_elements_processed{};
ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info, fp_mixed_precision));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
input1->clone().get(),
output->clone().get(),
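A hedged usage sketch of the extended entry point: validating a mixed-precision F16 GEMM before configuring the kernel. The include path, shapes, and GPU target are illustrative; the trailing boolean is the fp_mixed_precision argument added by this patch:

#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "arm_compute/core/TensorInfo.h"

using namespace arm_compute;

// Shapes use ACL's (width, height) order: A is M x K, B is K x N, output M x N.
TensorInfo a(TensorShape(64U, 32U), 1, DataType::F16);   // input0: K = 64, M = 32
TensorInfo b(TensorShape(16U, 64U), 1, DataType::F16);   // input1: N = 16, K = 64
TensorInfo dst(TensorShape(16U, 32U), 1, DataType::F16); // output: N = 16, M = 32

const Status s = CLGEMMMatrixMultiplyKernel::validate(&a, &b, &dst,
                                                      1.f,               // alpha
                                                      false,             // is_interleaved_transposed
                                                      GEMMReshapeInfo(), // no reshape
                                                      GPUTarget::G71,    // a Bifrost target
                                                      true);             // fp_mixed_precision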