aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScatterLayerFixture.h
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2024-04-29 17:00:14 +0100
committerGunes Bayir <gunes.bayir@arm.com>2024-04-30 09:33:22 +0000
commit301e33f8f94be6427bf2377570388c379d8c8466 (patch)
tree95c37c7077cd6f2a5a2e7b763365d15112efa2dd /tests/validation/fixtures/ScatterLayerFixture.h
parente5ef8c159a14872dda5e36e320f07b0963858d8c (diff)
downloadComputeLibrary-301e33f8f94be6427bf2377570388c379d8c8466.tar.gz
Add fp16 and integer data type support for ScatterNd in Gpu
Resolves: COMPMID-6899 Change-Id: I3743f2c9e5c21e1ec9f4c81d08c148666afad33a Signed-off-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11505 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Jakub Sujak <jakub.sujak@arm.com> Reviewed-by: Sang Won Ha <sangwon.ha@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h21
1 files changed, 19 insertions, 2 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 4fb2d7f127..35e6b647f3 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -63,13 +63,30 @@ public:
protected:
template <typename U>
- void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f)
+ void fill(U &&tensor, int i)
{
switch(tensor.data_type())
{
case DataType::F32:
+ case DataType::F16:
{
- std::uniform_real_distribution<float> distribution(lo, hi);
+ std::uniform_real_distribution<float> distribution(-10.f, 10.f);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::S32:
+ case DataType::S16:
+ case DataType::S8:
+ {
+ std::uniform_int_distribution<int32_t> distribution(-100, 100);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::U32:
+ case DataType::U16:
+ case DataType::U8:
+ {
+ std::uniform_int_distribution<uint32_t> distribution(0, 200);
library->fill(tensor, distribution, i);
break;
}