aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScatterLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h82
1 files changed, 64 insertions, 18 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 91e28b58f7..4fb2d7f127 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -29,6 +29,7 @@
#include "tests/Globals.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ScatterLayer.h"
#include "tests/SimpleTensor.h"
@@ -46,9 +47,17 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScatterGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
+ void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
+ TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace,
+ QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
{
- _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
+ // this is for improving randomness across tests
+ _hash = src_shape[0] + src_shape[1] + src_shape[2] + src_shape[3] + src_shape[4] + src_shape[5]
+ + updates_shape[0] + updates_shape[1] + updates_shape[2] + updates_shape[3]
+ + updates_shape[4] + updates_shape[5]
+ + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
+
+ _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo);
_reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
}
@@ -81,7 +90,9 @@ protected:
library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
}
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
+ const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace,
+ QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// 1. Create relevant tensors using ScatterInfo data structure.
// ----------------------------------------------------
@@ -94,14 +105,22 @@ protected:
FunctionType scatter;
// Configure operator
- // When scatter_info.zero_initialization is true, pass nullptr to scatter function.
+ // When scatter_info.zero_initialization is true, pass nullptr for src
+ // because dst does not need to be initialized with src values.
if(info.zero_initialization)
{
scatter.configure(nullptr, &updates, &indices, &dst, info);
}
else
{
- scatter.configure(&src, &updates, &indices, &dst, info);
+ if(inplace)
+ {
+ scatter.configure(&src, &updates, &indices, &src, info);
+ }
+ else
+ {
+ scatter.configure(&src, &updates, &indices, &dst, info);
+ }
}
// Assertions
@@ -110,28 +129,51 @@ protected:
ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
+ add_padding_x({ &src, &updates, &indices});
+
+ if(!inplace)
+ {
+ add_padding_x({ &dst });
+ }
+
// Allocate tensors
src.allocator()->allocate();
updates.allocator()->allocate();
indices.allocator()->allocate();
- dst.allocator()->allocate();
+
+ if(!inplace)
+ {
+ dst.allocator()->allocate();
+ }
ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
- ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+ if(!inplace)
+ {
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+ }
// Fill update (a) and indices (b) tensors.
- fill(AccessorType(src), 0);
- fill(AccessorType(updates), 1);
- fill_indices(AccessorType(indices), 2, out_shape);
+ fill(AccessorType(src), 0 + _hash);
+ fill(AccessorType(updates), 1+ _hash);
+ fill_indices(AccessorType(indices), 2 + _hash, out_shape);
scatter.run();
- return dst;
+
+ if(inplace)
+ {
+ return src;
+ }
+ else
+ {
+ return dst;
+ }
}
- SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type,
- ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape,
+ const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// Output Quantization not currently in use - fixture should be extended to support this.
ARM_COMPUTE_UNUSED(o_qinfo);
@@ -158,9 +200,9 @@ protected:
SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
// Fill reference
- fill(src, 0);
- fill(updates, 1);
- fill_indices(indices, 2, out_shape);
+ fill(src, 0 + _hash);
+ fill(updates, 1 + _hash);
+ fill_indices(indices, 2 + _hash, out_shape);
// Calculate individual reference.
return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
@@ -168,6 +210,7 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
+ int32_t _hash{};
};
// This fixture will use the same shape for updates as indices.
@@ -175,9 +218,12 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
+ void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,
+ TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace)
{
- ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
+ ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
+ indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace,
+ QuantizationInfo(), QuantizationInfo());
}
};