path: root/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
author     Vidhya Sudhan Loganathan <vidhyasudhan.loganathan@arm.com>  2018-11-16 11:33:12 +0000
committer  Georgios Pinitas <georgios.pinitas@arm.com>  2018-11-16 17:37:40 +0000
commit     a25d16c86f0d870408bc8b941aa755093417b0f0 (patch)
tree       b62d145a4e5009d894262a7ffa66cdba8260bb03 /src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
parent     a7b54f44e2bf133179f24a34007bc93237dd2265 (diff)
COMPMID-1266: Add support for FP16 in CLWinogradConvolutionLayer: 5x5 kernels

Introduced F32 accumulation for the F16 Winograd GEMM and output transform.
WinogradConvolution will be available for F16 only if the fast math flag is
enabled.

Change-Id: I215593c205236a0f9669218437bb40b184ec6a4f
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 19
1 file changed, 14 insertions(+), 5 deletions(-)
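The substance of the patch is to keep F16 storage while doing the GEMM
reduction in 32-bit floating point. As a minimal standalone sketch of why the
wider accumulator matters (using float/double as stand-ins for F16/F32, since
standard C++ has no portable half type): a long sum in the narrow type stops
growing once each addend falls below one ulp of the running total, while the
wide accumulator stays exact.

    // Build: g++ -O2 acc_demo.cpp && ./a.out
    #include <cstdio>

    int main()
    {
        const long n = 1L << 25;   // 33,554,432 additions of 1.0
        float  acc_narrow = 0.0f;  // stand-in for an F16 accumulator
        double acc_wide   = 0.0;   // stand-in for the F32 ("_acc32") accumulator
        for(long i = 0; i < n; ++i)
        {
            acc_narrow += 1.0f;    // stalls at 2^24: 1.0 is below 1 ulp of 16777216.0f
            acc_wide   += 1.0;     // exact in double well past 2^25
        }
        std::printf("narrow: %.1f  wide: %.1f\n", acc_narrow, acc_wide);
        // Expected output: narrow: 16777216.0  wide: 33554432.0
        return 0;
    }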
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 5e02dda9e3..b549638343 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -47,12 +47,14 @@ namespace
{
using ElementsProcessed = Steps;
-inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
+ bool fp_mixed_precision)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((fp_mixed_precision && (input0->data_type() != DataType::F16)), "Mixed precision floating point is supported only for F16 data");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
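The new ARM_COMPUTE_RETURN_ERROR_ON_MSG guard rejects the fp_mixed_precision
flag for anything other than F16 inputs. A hypothetical restatement of the
predicate, for illustration only (the helper name is not part of the library):

    #include "arm_compute/core/Types.h"

    // True when the combination would pass the new check.
    inline bool mixed_precision_allowed(arm_compute::DataType dt, bool fp_mixed_precision)
    {
        return !fp_mixed_precision || dt == arm_compute::DataType::F16;
    }

    // mixed_precision_allowed(DataType::F16, true)  -> true  (F32 accumulation requested)
    // mixed_precision_allowed(DataType::F32, true)  -> false (rejected with the message above)
    // mixed_precision_allowed(DataType::F32, false) -> true  (flag unused, normal path)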
@@ -216,12 +218,13 @@ CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
{
}
-void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
+ bool fp_mixed_precision)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
// Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, fp_mixed_precision));
_input0 = input0;
_input1 = input1;
@@ -316,6 +319,11 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// The work-group size equal to the Bifrost quad size has been proved to be optimal for these kernels
// via exhaustive autotuning over a range of representative layer configurations.
set_lws_hint(cl::NDRange(4));
+ if(fp_mixed_precision && data_type == DataType::F16)
+ {
+ // Currently, the wider accumulator is only supported for FP16 kernels.
+ kernel_name += "_acc32";
+ }
}
else // (MIDGARD and F32) or (F16)
{
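On this Bifrost F16 path the "_acc32" suffix selects an OpenCL kernel variant
that accumulates in 32-bit floats while still reading and writing F16 data.
Illustratively, a base name along the lines of "gemm_mm_floating_point_f16_bifrost"
would become "gemm_mm_floating_point_f16_bifrost_acc32"; the exact base name is
chosen earlier in configure(), outside this hunk, so treat the example as an
assumption rather than a quote from the code.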
@@ -331,6 +339,7 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Set config_id for enabling LWS tuning
_config_id = "gemm_";
_config_id += (is_interleaved_transposed ? "reshaped_" : "");
+ _config_id += (fp_mixed_precision ? "fp_mixed_" : "");
_config_id += (_reinterpret_input_as_3d ? "3di_" : "");
_config_id += (_reinterpret_output_as_3d ? "3do_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
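As an illustrative example, a reshaped mixed-precision F16 configuration would
now produce a tuning id beginning "gemm_reshaped_fp_mixed_f16", followed by the
shape components appended further down in configure(). Without the new
"fp_mixed_" token, the mixed and non-mixed variants would share a config_id and
could not be tuned independently.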
@@ -347,12 +356,12 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
}
Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed,
- const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target)
+ const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
{
// Note: num_elements_processed will be set in validate_and_configure_window()
ElementsProcessed num_elements_processed{};
ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info, fp_mixed_precision));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
input1->clone().get(),
output->clone().get(),
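A host-side sketch of exercising the extended static validate(); the tensor
shapes, the default-constructed GEMMReshapeInfo, and the GPU target are
illustrative assumptions rather than values taken from this patch:

    #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
    #include "arm_compute/core/Error.h"
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"

    using namespace arm_compute;

    void check_mixed_precision()
    {
        // M = 32, K = 64, N = 48; ACL shapes are (width, height), so A is (K, M),
        // B is (N, K) and the output is (N, M).
        const TensorInfo a(TensorShape(64U, 32U), 1, DataType::F16);
        const TensorInfo b(TensorShape(48U, 64U), 1, DataType::F16);
        const TensorInfo d(TensorShape(48U, 32U), 1, DataType::F16);

        // F16 inputs with fp_mixed_precision = true: expected to pass the new check.
        const Status ok = CLGEMMMatrixMultiplyKernel::validate(
            &a, &b, &d, 1.0f, false /* is_interleaved_transposed */,
            GEMMReshapeInfo(), GPUTarget::BIFROST, true /* fp_mixed_precision */);
        ARM_COMPUTE_ERROR_THROW_ON(ok);

        // The same call with F32 tensors would instead return a Status carrying
        // "Mixed precision floating point is supported only for F16 data".
    }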