aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScatterLayerFixture.h
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2024-04-24 10:27:13 +0100
committerGunes Bayir <gunes.bayir@arm.com>2024-04-25 15:41:36 +0000
commitada3200f5cec0b6a37f898d5d6f8e69395d7bcb1 (patch)
tree4e0b1e1e5c10a07307c1d2ac9a76c606849629b1 /tests/validation/fixtures/ScatterLayerFixture.h
parent62d600fe8c0afba81cc5f5dd315eb6dcc04f90b8 (diff)
downloadComputeLibrary-ada3200f5cec0b6a37f898d5d6f8e69395d7bcb1.tar.gz
Add update/index/output (m+1)/2d/(m+n) support for CLScatter
Resolves: COMPMID-6894, COMPMID-6896 Change-Id: I9d29fd3701a7e0f28d83f81a6c42a7234c2587c3 Signed-off-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11477 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Ramy Elgammal <ramy.elgammal@arm.com> Dynamic-Fusion: Ramy Elgammal <ramy.elgammal@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h82
1 files changed, 64 insertions, 18 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 91e28b58f7..4fb2d7f127 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -29,6 +29,7 @@
#include "tests/Globals.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ScatterLayer.h"
#include "tests/SimpleTensor.h"
@@ -46,9 +47,17 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScatterGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
+ void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
+ TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace,
+ QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
{
- _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
+ // this is for improving randomness across tests
+ _hash = src_shape[0] + src_shape[1] + src_shape[2] + src_shape[3] + src_shape[4] + src_shape[5]
+ + updates_shape[0] + updates_shape[1] + updates_shape[2] + updates_shape[3]
+ + updates_shape[4] + updates_shape[5]
+ + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
+
+ _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo);
_reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
}
@@ -81,7 +90,9 @@ protected:
library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
}
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
+ const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace,
+ QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// 1. Create relevant tensors using ScatterInfo data structure.
// ----------------------------------------------------
@@ -94,14 +105,22 @@ protected:
FunctionType scatter;
// Configure operator
- // When scatter_info.zero_initialization is true, pass nullptr to scatter function.
+ // When scatter_info.zero_initialization is true, pass nullptr for src
+ // because dst does not need to be initialized with src values.
if(info.zero_initialization)
{
scatter.configure(nullptr, &updates, &indices, &dst, info);
}
else
{
- scatter.configure(&src, &updates, &indices, &dst, info);
+ if(inplace)
+ {
+ scatter.configure(&src, &updates, &indices, &src, info);
+ }
+ else
+ {
+ scatter.configure(&src, &updates, &indices, &dst, info);
+ }
}
// Assertions
@@ -110,28 +129,51 @@ protected:
ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
+ add_padding_x({ &src, &updates, &indices});
+
+ if(!inplace)
+ {
+ add_padding_x({ &dst });
+ }
+
// Allocate tensors
src.allocator()->allocate();
updates.allocator()->allocate();
indices.allocator()->allocate();
- dst.allocator()->allocate();
+
+ if(!inplace)
+ {
+ dst.allocator()->allocate();
+ }
ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
- ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+ if(!inplace)
+ {
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+ }
// Fill update (a) and indices (b) tensors.
- fill(AccessorType(src), 0);
- fill(AccessorType(updates), 1);
- fill_indices(AccessorType(indices), 2, out_shape);
+ fill(AccessorType(src), 0 + _hash);
+ fill(AccessorType(updates), 1+ _hash);
+ fill_indices(AccessorType(indices), 2 + _hash, out_shape);
scatter.run();
- return dst;
+
+ if(inplace)
+ {
+ return src;
+ }
+ else
+ {
+ return dst;
+ }
}
- SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type,
- ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape,
+ const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// Output Quantization not currently in use - fixture should be extended to support this.
ARM_COMPUTE_UNUSED(o_qinfo);
@@ -158,9 +200,9 @@ protected:
SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
// Fill reference
- fill(src, 0);
- fill(updates, 1);
- fill_indices(indices, 2, out_shape);
+ fill(src, 0 + _hash);
+ fill(updates, 1 + _hash);
+ fill_indices(indices, 2 + _hash, out_shape);
// Calculate individual reference.
return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
@@ -168,6 +210,7 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
+ int32_t _hash{};
};
// This fixture will use the same shape for updates as indices.
@@ -175,9 +218,12 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
+ void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,
+ TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace)
{
- ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
+ ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
+ indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace,
+ QuantizationInfo(), QuantizationInfo());
}
};