aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScatterLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h38
1 files changed, 21 insertions, 17 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 35e6b647f3..af161ef98b 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -48,7 +48,7 @@ class ScatterGenericValidationFixture : public framework::Fixture
{
public:
void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
- TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace,
+ TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace, bool padding,
QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
{
// this is for improving randomness across tests
@@ -57,7 +57,7 @@ public:
+ updates_shape[4] + updates_shape[5]
+ indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
- _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo);
+ _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, padding, src_qinfo, o_qinfo);
_reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
}
@@ -103,12 +103,12 @@ protected:
void fill_indices(U &&tensor, int i, const TensorShape &shape)
{
// Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension)
- const int32_t max = std::max({shape[0] , shape[1], shape[2]});
- library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
+ const int32_t max = std::min({shape[0] , shape[1], shape[2]}) + 1;
+ library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(0), static_cast<int32_t>(max));
}
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
- const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace,
+ const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, bool padding,
QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// 1. Create relevant tensors using ScatterInfo data structure.
@@ -146,11 +146,14 @@ protected:
ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
- add_padding_x({ &src, &updates, &indices});
-
- if(!inplace)
+ if(padding)
{
- add_padding_x({ &dst });
+ add_padding_x({ &src, &updates, &indices});
+
+ if(!inplace)
+ {
+ add_padding_x({ &dst });
+ }
}
// Allocate tensors
@@ -197,12 +200,13 @@ protected:
TensorShape src_shape = a_shape;
TensorShape updates_shape = b_shape;
TensorShape indices_shape = c_shape;
+ const int num_ind_dims = c_shape.num_dimensions();
// 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor.
- if(c_shape.num_dimensions() >= 3)
+ if(num_ind_dims >= 3)
{
indices_shape = indices_shape.collapsed_from(1);
- updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses from last 2 dims
+ updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - (num_ind_dims -1)); // Collapses batch dims
}
// 2. Collapse data dims into a single dim.
@@ -212,16 +216,16 @@ protected:
updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim)
// Create reference tensors
- SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };
- SimpleTensor<T> updates{b_shape, data_type, 1, QuantizationInfo() };
- SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
+ SimpleTensor<T> src{ src_shape, data_type, 1, a_qinfo };
+ SimpleTensor<T> updates{updates_shape, data_type, 1, QuantizationInfo() };
+ SimpleTensor<int32_t> indices{ indices_shape, DataType::S32, 1, QuantizationInfo() };
// Fill reference
fill(src, 0 + _hash);
fill(updates, 1 + _hash);
fill_indices(indices, 2 + _hash, out_shape);
- // Calculate individual reference.
+ // Calculate individual reference using collapsed shapes
return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
}
@@ -236,10 +240,10 @@ class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorTy
{
public:
void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,
- TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace)
+ TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace, bool padding)
{
ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
- indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace,
+ indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, padding,
QuantizationInfo(), QuantizationInfo());
}
};