From 301e33f8f94be6427bf2377570388c379d8c8466 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Mon, 29 Apr 2024 17:00:14 +0100 Subject: Add fp16 and integer data type support for ScatterNd in Gpu Resolves: COMPMID-6899 Change-Id: I3743f2c9e5c21e1ec9f4c81d08c148666afad33a Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11505 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Jakub Sujak Reviewed-by: Sang Won Ha Comments-Addressed: Arm Jenkins --- src/core/CL/cl_kernels/common/scatter.cl | 7 +++++++ src/gpu/cl/kernels/ClScatterKernel.cpp | 5 ++++- 2 files changed, 11 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl index ac9f828df2..e3ec9cc98e 100644 --- a/src/core/CL/cl_kernels/common/scatter.cl +++ b/src/core/CL/cl_kernels/common/scatter.cl @@ -28,8 +28,15 @@ // Where a corresponds to the existing value, and b the new value. #define ADD_OP(a, b) ((a) + (b)) #define SUB_OP(a, b) ((a) - (b)) + +#ifdef IS_FLOAT #define MAX_OP(a, b) fmax(a, b) #define MIN_OP(a, b) fmin(a, b) +#else // ifdef IS_FLOAT +#define MAX_OP(a, b) max(a, b) +#define MIN_OP(a, b) min(a, b) +#endif // ifdef IS_FLOAT + #define UPDATE_OP(a, b) (b) #ifdef SCATTER_MP1D_2D_MPND diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp index 9c25b63c72..21c0253f91 100644 --- a/src/gpu/cl/kernels/ClScatterKernel.cpp +++ b/src/gpu/cl/kernels/ClScatterKernel.cpp @@ -27,6 +27,7 @@ #include "arm_compute/core/ITensorPack.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/DataTypeUtils.h" #include "arm_compute/core/utils/helpers/AdjustVecSize.h" #include "src/common/utils/Log.h" @@ -70,7 +71,8 @@ Status ClScatterKernel::validate(const ITensorInfo *updates, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16, + DataType::S8, DataType::U32, DataType::U16, DataType::U8); ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_dims > 2, "Only 2D indices tensors are currently supported."); ARM_COMPUTE_RETURN_ERROR_ON_MSG( ind_shape[1] != upt_shape[upt_dims - 1], @@ -116,6 +118,7 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context, // Set build options CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type())); + build_opts.add_option_if(is_data_type_float(dst->data_type()), "-DIS_FLOAT"); const int num_dims = dst->num_dimensions(); -- cgit v1.2.1