about | summary | refs | log | tree | commit | diff
path: root/tests
diff options
context:
space:
mode:
author: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> 2024-03-25 15:55:42 +0000
committer: Suhail M <MohammedSuhail.Munshi@arm.com> 2024-04-22 14:44:09 +0000
commit: 7377107378d6c26439320fce78a551e85b5ad36a (patch)
tree: 3aa9c74c59993f9d51924fc123eefa17e3376a79 /tests
parent: 5057ce9e1866ffa0388543d81af32083b5b1c684 (diff)
download: ComputeLibrary-7377107378d6c26439320fce78a551e85b5ad36a.tar.gz
Scatter GPU Kernel Implementation for 1D tensors.
Resolves: [COMPMID-6891, COMPMID-6892]
Change-Id: I5b094fff1bff4c4c59cc44f7d6beab0e40133d8e
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11394
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r-- tests/datasets/ScatterDataset.h | 4
-rw-r--r-- tests/validation/CL/ScatterLayer.cpp | 38
-rw-r--r-- tests/validation/fixtures/ScatterLayerFixture.h | 19
-rw-r--r-- tests/validation/reference/ScatterLayer.cpp | 10
-rw-r--r-- tests/validation/reference/ScatterLayer.h | 4
5 files changed, 38 insertions, 37 deletions
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index d204d17855..f7547ecc94 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -118,8 +118,8 @@ class Small1DScatterDataset final : public ScatterDataset
public:
Small1DScatterDataset()
{
- add_config(TensorShape(6U), TensorShape(6U), TensorShape(6U), TensorShape(6U));
- add_config(TensorShape(10U), TensorShape(2U), TensorShape(2U), TensorShape(10U));
+ add_config(TensorShape(6U), TensorShape(6U), TensorShape(1U, 6U), TensorShape(6U));
+ add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
}
};
} // namespace datasets
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index 56338f489f..9711671841 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -38,6 +38,10 @@ namespace test
{
namespace validation
{
+namespace
+{
+RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */
+} // namespace
template <typename T>
using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
@@ -46,7 +50,7 @@ using framework::dataset::make;
TEST_SUITE(CL)
TEST_SUITE(Scatter)
-DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
+DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
TensorInfo(TensorShape(8U), 1, DataType::F32),
@@ -61,12 +65,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
TensorInfo(TensorShape(2U), 1, DataType::F32),
}),
- make("IndicesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U32),
- TensorInfo(TensorShape(15U), 1, DataType::U32),
- TensorInfo(TensorShape(2U), 1, DataType::U32),
- TensorInfo(TensorShape(271U), 1, DataType::U32),
- TensorInfo(TensorShape(271U), 1, DataType::U32),
- TensorInfo(TensorShape(2U), 1 , DataType::S32)
+ make("IndicesInfo",{ TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32)
}),
make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16),
TensorInfo(TensorShape(15U), 1, DataType::F32),
@@ -76,27 +80,27 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
TensorInfo(TensorShape(12U), 1, DataType::F32)
}),
make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
+ ScatterInfo(ScatterFunction::Max, false),
+ ScatterInfo(ScatterFunction::Min, false),
+ ScatterInfo(ScatterFunction::Add, false),
+ ScatterInfo(ScatterFunction::Update, false),
+ ScatterInfo(ScatterFunction::Sub, false),
}),
make("Expected", { false, true, true, false, false, false })),
input_info, updates_info, indices_info, output_info, scatter_info, expected)
{
- // TODO: Enable validation tests.
- ARM_COMPUTE_UNUSED(input_info);
- ARM_COMPUTE_UNUSED(updates_info);
- ARM_COMPUTE_UNUSED(indices_info);
- ARM_COMPUTE_UNUSED(output_info);
- ARM_COMPUTE_UNUSED(scatter_info);
- ARM_COMPUTE_UNUSED(expected);
+ const Status status = CLScatter::validate(&input_info.clone()->set_is_resizable(true), &updates_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), scatter_info);
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
TEST_SUITE(Float)
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
make("DataType", {DataType::F32}),
- make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max}),
+ make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }),
make("ZeroInit", {false})))
{
- // TODO: Add validate() here.
+ validate(CLAccessor(_target), _reference, tolerance_f32);
}
// With this test, src should be passed as nullptr.
@@ -105,7 +109,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework
make("ScatterFunction", {ScatterFunction::Add}),
make("ZeroInit", {true})))
{
- // TODO: Add validate() here
+ validate(CLAccessor(_target), _reference, tolerance_f32);
}
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index bda5532a51..451a1e1416 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -27,7 +27,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "tests/Globals.h"
-#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
+#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ScatterLayer.h"
@@ -71,14 +71,14 @@ protected:
}
}
- // This is used to fill indices tensor with U32 datatype.
+ // This is used to fill indices tensor with S32 datatype.
// Used to prevent ONLY having values that are out of bounds.
template <typename U>
void fill_indices(U &&tensor, int i, const TensorShape &shape)
{
- // Calculate max indices the shape should contain. Add an arbitrary constant to allow testing for some out of bounds values.
- const uint32_t max = std::max({shape[0] , shape[1], shape[2]}) + 5;
- library->fill_tensor_uniform(tensor, i, static_cast<uint32_t>(0), static_cast<uint32_t>(max));
+ // Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension)
+ const int32_t max = std::max({shape[0] , shape[1], shape[2]});
+ library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
}
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
@@ -88,7 +88,7 @@ protected:
// In order - src, updates, indices, output.
TensorType src = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
TensorType updates = create_tensor<TensorType>(shape_b, data_type, 1, a_qinfo);
- TensorType indices = create_tensor<TensorType>(shape_c, DataType::U32, 1, QuantizationInfo());
+ TensorType indices = create_tensor<TensorType>(shape_c, DataType::S32, 1, QuantizationInfo());
TensorType dst = create_tensor<TensorType>(out_shape, data_type, 1, o_qinfo);
FunctionType scatter;
@@ -127,7 +127,6 @@ protected:
fill_indices(AccessorType(indices), 2, out_shape);
scatter.run();
-
return dst;
}
@@ -140,7 +139,7 @@ protected:
// Create reference tensors
SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };
SimpleTensor<T> updates{b_shape, data_type, 1, QuantizationInfo() };
- SimpleTensor<uint32_t> indices{ c_shape, DataType::U32, 1, QuantizationInfo() };
+ SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
// Fill reference
fill(src, 0);
@@ -148,9 +147,7 @@ protected:
fill_indices(indices, 2, out_shape);
// Calculate individual reference.
- auto result = reference::scatter_layer<T>(src, updates, indices, out_shape, info);
-
- return result;
+ return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
}
TensorType _target{};
diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp
index 920f2b9990..7543b46bb1 100644
--- a/tests/validation/reference/ScatterLayer.cpp
+++ b/tests/validation/reference/ScatterLayer.cpp
@@ -66,7 +66,7 @@ template float reduce_op(const float &current,const float &update,const ScatterF
// Note : This function currently only supports 1D src, 1D updates, 2D indices, 1D output tensors.
template <typename T>
-SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
{
SimpleTensor<T> dst{ out_shape, src.data_type(), 1 };
@@ -84,14 +84,14 @@ SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleT
}
// 2. Get max index of output tensor, then iterate over index tensor.
- const auto x_bound = dst.shape().x();
+ const int x_bound = static_cast<int>(dst.shape().x());
for(int i = 0; i < indices.num_elements(); ++i)
{
// 3. Check whether index is out of bounds for dst, if not then apply reduce op.
const auto index = indices[i];
- if (index < x_bound) // Note : index is always >= 0 as datatype is unsigned.
+ if (index < x_bound && index >= 0) // Note : we ignore negative index values.
{
dst[index] = reduce_op(dst[index], updates[i], info.func);
}
@@ -100,12 +100,12 @@ SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleT
}
template <typename T>
-SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
{
return scatter_layer_internal<T>(src, updates, indices, out_shape, info);
}
-template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/reference/ScatterLayer.h b/tests/validation/reference/ScatterLayer.h
index dc441a8894..97d5e70b0d 100644
--- a/tests/validation/reference/ScatterLayer.h
+++ b/tests/validation/reference/ScatterLayer.h
@@ -37,10 +37,10 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<int32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
template <typename T>
-SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<int32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
} // namespace reference
} // namespace validation
} // namespace test