aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2024-04-29 17:00:14 +0100
committerGunes Bayir <gunes.bayir@arm.com>2024-04-30 09:33:22 +0000
commit301e33f8f94be6427bf2377570388c379d8c8466 (patch)
tree95c37c7077cd6f2a5a2e7b763365d15112efa2dd /src
parente5ef8c159a14872dda5e36e320f07b0963858d8c (diff)
downloadComputeLibrary-301e33f8f94be6427bf2377570388c379d8c8466.tar.gz
Add fp16 and integer data type support for ScatterNd in Gpu
Resolves: COMPMID-6899 Change-Id: I3743f2c9e5c21e1ec9f4c81d08c148666afad33a Signed-off-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11505 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Jakub Sujak <jakub.sujak@arm.com> Reviewed-by: Sang Won Ha <sangwon.ha@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/core/CL/cl_kernels/common/scatter.cl7
-rw-r--r--src/gpu/cl/kernels/ClScatterKernel.cpp5
2 files changed, 11 insertions, 1 deletions
diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl
index ac9f828df2..e3ec9cc98e 100644
--- a/src/core/CL/cl_kernels/common/scatter.cl
+++ b/src/core/CL/cl_kernels/common/scatter.cl
@@ -28,8 +28,15 @@
// Where a corresponds to the existing value, and b the new value.
#define ADD_OP(a, b) ((a) + (b))
#define SUB_OP(a, b) ((a) - (b))
+
+#ifdef IS_FLOAT
#define MAX_OP(a, b) fmax(a, b)
#define MIN_OP(a, b) fmin(a, b)
+#else // ifdef IS_FLOAT
+#define MAX_OP(a, b) max(a, b)
+#define MIN_OP(a, b) min(a, b)
+#endif // ifdef IS_FLOAT
+
#define UPDATE_OP(a, b) (b)
#ifdef SCATTER_MP1D_2D_MPND
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index 9c25b63c72..21c0253f91 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/DataTypeUtils.h"
#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
#include "src/common/utils/Log.h"
@@ -70,7 +71,8 @@ Status ClScatterKernel::validate(const ITensorInfo *updates,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16,
+ DataType::S8, DataType::U32, DataType::U16, DataType::U8);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_dims > 2, "Only 2D indices tensors are currently supported.");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
ind_shape[1] != upt_shape[upt_dims - 1],
@@ -116,6 +118,7 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context,
// Set build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
+ build_opts.add_option_if(is_data_type_float(dst->data_type()), "-DIS_FLOAT");
const int num_dims = dst->num_dimensions();