aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2024-04-29 17:00:14 +0100
committerGunes Bayir <gunes.bayir@arm.com>2024-04-30 09:33:22 +0000
commit301e33f8f94be6427bf2377570388c379d8c8466 (patch)
tree95c37c7077cd6f2a5a2e7b763365d15112efa2dd
parente5ef8c159a14872dda5e36e320f07b0963858d8c (diff)
downloadComputeLibrary-301e33f8f94be6427bf2377570388c379d8c8466.tar.gz
Add fp16 and integer data type support for ScatterNd in Gpu
Resolves: COMPMID-6899 Change-Id: I3743f2c9e5c21e1ec9f4c81d08c148666afad33a Signed-off-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11505 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Jakub Sujak <jakub.sujak@arm.com> Reviewed-by: Sang Won Ha <sangwon.ha@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--docs/user_guide/release_version_and_change_log.dox2
-rw-r--r--src/core/CL/cl_kernels/common/scatter.cl7
-rw-r--r--src/gpu/cl/kernels/ClScatterKernel.cpp5
-rw-r--r--tests/datasets/ScatterDataset.h14
-rw-r--r--tests/validation/CL/ScatterLayer.cpp96
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h21
-rw-r--r--tests/validation/reference/ScatterLayer.cpp8
7 files changed, 147 insertions, 6 deletions
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index 952753effb..ca8092797f 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -42,7 +42,7 @@ If there is more than one release in a month then an extra sequential number is
@section S2_2_changelog Changelog
v24.05 Public major release
- - Add @ref CLScatter operator for FP32 data type
+ - Add @ref CLScatter operator for FP32/16, S32/16/8, U32/16/8 data types
v24.04 Public major release
- Add Bfloat16 data type support for @ref NEMatMul.
diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl
index ac9f828df2..e3ec9cc98e 100644
--- a/src/core/CL/cl_kernels/common/scatter.cl
+++ b/src/core/CL/cl_kernels/common/scatter.cl
@@ -28,8 +28,15 @@
// Where a corresponds to the existing value, and b the new value.
#define ADD_OP(a, b) ((a) + (b))
#define SUB_OP(a, b) ((a) - (b))
+
+#ifdef IS_FLOAT
#define MAX_OP(a, b) fmax(a, b)
#define MIN_OP(a, b) fmin(a, b)
+#else // ifdef IS_FLOAT
+#define MAX_OP(a, b) max(a, b)
+#define MIN_OP(a, b) min(a, b)
+#endif // ifdef IS_FLOAT
+
#define UPDATE_OP(a, b) (b)
#ifdef SCATTER_MP1D_2D_MPND
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index 9c25b63c72..21c0253f91 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/DataTypeUtils.h"
#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
#include "src/common/utils/Log.h"
@@ -70,7 +71,8 @@ Status ClScatterKernel::validate(const ITensorInfo *updates,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16,
+ DataType::S8, DataType::U32, DataType::U16, DataType::U8);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_dims > 2, "Only 2D indices tensors are currently supported.");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
ind_shape[1] != upt_shape[upt_dims - 1],
@@ -116,6 +118,7 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context,
// Set build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
+ build_opts.add_option_if(is_data_type_float(dst->data_type()), "-DIS_FLOAT");
const int num_dims = dst->num_dimensions();
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index c0858941db..9dcf859a8f 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -185,6 +185,20 @@ public:
add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(6U, 2U), TensorShape(5U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
}
};
+
+// This dataset is for data types that does not require full testing. It contains selected tests from the above.
+class SmallScatterMixedDataset final : public ScatterDataset
+{
+public:
+ SmallScatterMixedDataset()
+ {
+ add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
+ add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
+ add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
+ add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
+ // TODO: add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U));
+ }
+};
} // namespace datasets
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index 4a2462c7d2..2970d82572 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -41,6 +41,8 @@ namespace validation
namespace
{
RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */
+RelativeTolerance<float> tolerance_f16(0.02f); /**< Tolerance value for comparing reference's output against implementation's output for fp16 data type */
+RelativeTolerance<int32_t> tolerance_int(0); /**< Tolerance value for comparing reference's output against implementation's output for integer data types */
} // namespace
template <typename T>
@@ -53,6 +55,7 @@ TEST_SUITE(Scatter)
DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
+ TensorInfo(TensorShape(15U), 1, DataType::U8), // Valid
TensorInfo(TensorShape(8U), 1, DataType::F32),
TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
@@ -63,6 +66,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
}),
make("UpdatesInfo",{TensorInfo(TensorShape(3U), 1, DataType::F16),
TensorInfo(TensorShape(15U), 1, DataType::F32),
+ TensorInfo(TensorShape(15U), 1, DataType::U8),
TensorInfo(TensorShape(2U), 1, DataType::F32),
TensorInfo(TensorShape(217U), 1, DataType::F32),
TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
@@ -73,6 +77,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
}),
make("IndicesInfo",{TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
@@ -83,6 +88,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
}),
make("OutputInfo",{TensorInfo(TensorShape(9U), 1, DataType::F16),
TensorInfo(TensorShape(15U), 1, DataType::F32),
+ TensorInfo(TensorShape(15U), 1, DataType::U8),
TensorInfo(TensorShape(8U), 1, DataType::F32),
TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
TensorInfo(TensorShape(271U), 1, DataType::F32),
@@ -93,6 +99,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
}),
make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
ScatterInfo(ScatterFunction::Max, false),
+ ScatterInfo(ScatterFunction::Max, false),
ScatterInfo(ScatterFunction::Min, false),
ScatterInfo(ScatterFunction::Add, false),
ScatterInfo(ScatterFunction::Update, false),
@@ -101,7 +108,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
ScatterInfo(ScatterFunction::Update, false),
ScatterInfo(ScatterFunction::Update, false),
}),
- make("Expected", { false, true, true, false, false, false, false, false, false })),
+ make("Expected", { false, true, true, true, false, false, false, false, false, false })),
input_info, updates_info, indices_info, output_info, scatter_info, expected)
{
const Status status = CLScatter::validate(&input_info, &updates_info, &indices_info, &output_info, scatter_info);
@@ -168,7 +175,94 @@ FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture<float>
}
TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMixedDataset(),
+ make("DataType", {DataType::F16}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+TEST_SUITE_END() // FP16
TEST_SUITE_END() // Float
+
+TEST_SUITE(Integer)
+TEST_SUITE(S32)
+FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int32_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMixedDataset(),
+ make("DataType", {DataType::S32}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_int);
+}
+TEST_SUITE_END() // S32
+
+TEST_SUITE(S16)
+FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int16_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMixedDataset(),
+ make("DataType", {DataType::S16}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_int);
+}
+TEST_SUITE_END() // S16
+
+TEST_SUITE(S8)
+FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMixedDataset(),
+ make("DataType", {DataType::S8}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_int);
+}
+TEST_SUITE_END() // S8
+
+TEST_SUITE(U32)
+FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint32_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMixedDataset(),
+ make("DataType", {DataType::U32}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_int);
+}
+TEST_SUITE_END() // U32
+
+TEST_SUITE(U16)
+FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint16_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMixedDataset(),
+ make("DataType", {DataType::U16}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_int);
+}
+TEST_SUITE_END() // U16
+
+TEST_SUITE(U8)
+FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMixedDataset(),
+ make("DataType", {DataType::U8}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_int);
+}
+TEST_SUITE_END() // U8
+TEST_SUITE_END() // Integer
+
TEST_SUITE_END() // Scatter
TEST_SUITE_END() // CL
} // namespace validation
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 4fb2d7f127..35e6b647f3 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -63,13 +63,30 @@ public:
protected:
template <typename U>
- void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f)
+ void fill(U &&tensor, int i)
{
switch(tensor.data_type())
{
case DataType::F32:
+ case DataType::F16:
{
- std::uniform_real_distribution<float> distribution(lo, hi);
+ std::uniform_real_distribution<float> distribution(-10.f, 10.f);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::S32:
+ case DataType::S16:
+ case DataType::S8:
+ {
+ std::uniform_int_distribution<int32_t> distribution(-100, 100);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::U32:
+ case DataType::U16:
+ case DataType::U8:
+ {
+ std::uniform_int_distribution<uint32_t> distribution(0, 200);
library->fill(tensor, distribution, i);
break;
}
diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp
index 283022e8e2..c9e6035e14 100644
--- a/tests/validation/reference/ScatterLayer.cpp
+++ b/tests/validation/reference/ScatterLayer.cpp
@@ -138,7 +138,13 @@ SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T>
}
template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
-
+template SimpleTensor<half> scatter_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<int32_t> scatter_layer(const SimpleTensor<int32_t> &src, const SimpleTensor<int32_t> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<uint32_t> scatter_layer(const SimpleTensor<uint32_t> &src, const SimpleTensor<uint32_t> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<int16_t> scatter_layer(const SimpleTensor<int16_t> &src, const SimpleTensor<int16_t> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<uint16_t> scatter_layer(const SimpleTensor<uint16_t> &src, const SimpleTensor<uint16_t> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<int8_t> scatter_layer(const SimpleTensor<int8_t> &src, const SimpleTensor<int8_t> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<uint8_t> scatter_layer(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
} // namespace reference
} // namespace validation
} // namespace test