aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp16
1 files changed, 2 insertions, 14 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 2c2a92d070..814cbb631f 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -31,7 +31,6 @@
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
@@ -53,10 +52,8 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_fixed_point(input0->data_type()) && (reshape_info.depth_output_gemm3d() != 1), "GEMM3D only supports floating point data types");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
@@ -95,7 +92,6 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
}
return Status{};
@@ -219,7 +215,6 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
_slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
const DataType data_type = input0->info()->data_type();
- const int fp_pos = input0->info()->fixed_point_position();
// Get target architecture
GPUTarget gpu_target = get_target();
@@ -236,14 +231,11 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Create build options
CLBuildOptions build_opts;
- build_opts.add_option_if(is_data_type_fixed_point(data_type), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(fp_pos));
// Only define ALPHA when alpha is not 1.0f. This avoids performing unnecessary multiplications.
if(std::abs(1.0f - alpha) > 0.00001f)
{
- build_opts.add_option_if_else(is_data_type_fixed_point(data_type),
- "-DALPHA=" + support::cpp11::to_string((data_type == DataType::QS8 ? sqcvt_qs8_f32(alpha, fp_pos) : sqcvt_qs16_f32(alpha, fp_pos))),
- "-DALPHA=" + float_to_string_with_full_precision(alpha));
+ build_opts.add_option("-DALPHA=" + float_to_string_with_full_precision(alpha));
}
build_opts.add_option_if(_is_gemm3d, "-DREINTERPRET_OUTPUT_AS_3D");
build_opts.add_option_if(_is_gemm3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
@@ -299,10 +291,6 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// via exhaustive autotuning over a range of representative layer configurations.
_lws_hint = cl::NDRange(4);
}
- else if(is_data_type_fixed_point(data_type))
- {
- kernel_name = "gemm_mm_" + lower_string(string_from_data_type(data_type));
- }
else // (MIDGARD and F32) or (F16)
{
kernel_name = "gemm_mm_floating_point";