aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScatterLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h146
1 files changed, 113 insertions, 33 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index bda5532a51..af161ef98b 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -27,8 +27,9 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "tests/Globals.h"
-#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
+#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ScatterLayer.h"
#include "tests/SimpleTensor.h"
@@ -46,21 +47,46 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScatterGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
+ void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
+ TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace, bool padding,
+ QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
{
- _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
+ // this is for improving randomness across tests
+ _hash = src_shape[0] + src_shape[1] + src_shape[2] + src_shape[3] + src_shape[4] + src_shape[5]
+ + updates_shape[0] + updates_shape[1] + updates_shape[2] + updates_shape[3]
+ + updates_shape[4] + updates_shape[5]
+ + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
+
+ _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, padding, src_qinfo, o_qinfo);
_reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
}
protected:
template <typename U>
- void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
+ void fill(U &&tensor, int i)
{
switch(tensor.data_type())
{
case DataType::F32:
+ case DataType::F16:
+ {
+ std::uniform_real_distribution<float> distribution(-10.f, 10.f);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::S32:
+ case DataType::S16:
+ case DataType::S8:
+ {
+ std::uniform_int_distribution<int32_t> distribution(-100, 100);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::U32:
+ case DataType::U16:
+ case DataType::U8:
{
- std::uniform_real_distribution<float> distribution(lo, hi);
+ std::uniform_int_distribution<uint32_t> distribution(0, 200);
library->fill(tensor, distribution, i);
break;
}
@@ -71,37 +97,47 @@ protected:
}
}
- // This is used to fill indices tensor with U32 datatype.
+ // This is used to fill indices tensor with S32 datatype.
// Used to prevent ONLY having values that are out of bounds.
template <typename U>
void fill_indices(U &&tensor, int i, const TensorShape &shape)
{
- // Calculate max indices the shape should contain. Add an arbitrary constant to allow testing for some out of bounds values.
- const uint32_t max = std::max({shape[0] , shape[1], shape[2]}) + 5;
- library->fill_tensor_uniform(tensor, i, static_cast<uint32_t>(0), static_cast<uint32_t>(max));
+ // Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension)
+ const int32_t max = std::min({shape[0] , shape[1], shape[2]}) + 1;
+ library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(0), static_cast<int32_t>(max));
}
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
+ const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, bool padding,
+ QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// 1. Create relevant tensors using ScatterInfo data structure.
// ----------------------------------------------------
// In order - src, updates, indices, output.
TensorType src = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
TensorType updates = create_tensor<TensorType>(shape_b, data_type, 1, a_qinfo);
- TensorType indices = create_tensor<TensorType>(shape_c, DataType::U32, 1, QuantizationInfo());
+ TensorType indices = create_tensor<TensorType>(shape_c, DataType::S32, 1, QuantizationInfo());
TensorType dst = create_tensor<TensorType>(out_shape, data_type, 1, o_qinfo);
FunctionType scatter;
// Configure operator
- // When scatter_info.zero_initialization is true, pass nullptr to scatter function.
+ // When scatter_info.zero_initialization is true, pass nullptr for src
+ // because dst does not need to be initialized with src values.
if(info.zero_initialization)
{
scatter.configure(nullptr, &updates, &indices, &dst, info);
}
else
{
- scatter.configure(&src, &updates, &indices, &dst, info);
+ if(inplace)
+ {
+ scatter.configure(&src, &updates, &indices, &src, info);
+ }
+ else
+ {
+ scatter.configure(&src, &updates, &indices, &dst, info);
+ }
}
// Assertions
@@ -110,51 +146,92 @@ protected:
ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
+ if(padding)
+ {
+ add_padding_x({ &src, &updates, &indices});
+
+ if(!inplace)
+ {
+ add_padding_x({ &dst });
+ }
+ }
+
// Allocate tensors
src.allocator()->allocate();
updates.allocator()->allocate();
indices.allocator()->allocate();
- dst.allocator()->allocate();
+
+ if(!inplace)
+ {
+ dst.allocator()->allocate();
+ }
ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
- ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+ if(!inplace)
+ {
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+ }
// Fill update (a) and indices (b) tensors.
- fill(AccessorType(src), 0);
- fill(AccessorType(updates), 1);
- fill_indices(AccessorType(indices), 2, out_shape);
+ fill(AccessorType(src), 0 + _hash);
+ fill(AccessorType(updates), 1+ _hash);
+ fill_indices(AccessorType(indices), 2 + _hash, out_shape);
scatter.run();
- return dst;
+ if(inplace)
+ {
+ return src;
+ }
+ else
+ {
+ return dst;
+ }
}
- SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type,
- ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape,
+ const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// Output Quantization not currently in use - fixture should be extended to support this.
ARM_COMPUTE_UNUSED(o_qinfo);
+ TensorShape src_shape = a_shape;
+ TensorShape updates_shape = b_shape;
+ TensorShape indices_shape = c_shape;
+ const int num_ind_dims = c_shape.num_dimensions();
+
+ // 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor.
+ if(num_ind_dims >= 3)
+ {
+ indices_shape = indices_shape.collapsed_from(1);
+ updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - (num_ind_dims -1)); // Collapses batch dims
+ }
+
+ // 2. Collapse data dims into a single dim.
+ // Collapse all src dims into 2 dims. First one holding data, the other being the index we iterate over.
+ src_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse all data dims into single dim.
+ src_shape = src_shape.collapsed_from(1); // Collapse all index dims into a single dim
+ updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim)
// Create reference tensors
- SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };
- SimpleTensor<T> updates{b_shape, data_type, 1, QuantizationInfo() };
- SimpleTensor<uint32_t> indices{ c_shape, DataType::U32, 1, QuantizationInfo() };
+ SimpleTensor<T> src{ src_shape, data_type, 1, a_qinfo };
+ SimpleTensor<T> updates{updates_shape, data_type, 1, QuantizationInfo() };
+ SimpleTensor<int32_t> indices{ indices_shape, DataType::S32, 1, QuantizationInfo() };
// Fill reference
- fill(src, 0);
- fill(updates, 1);
- fill_indices(indices, 2, out_shape);
-
- // Calculate individual reference.
- auto result = reference::scatter_layer<T>(src, updates, indices, out_shape, info);
+ fill(src, 0 + _hash);
+ fill(updates, 1 + _hash);
+ fill_indices(indices, 2 + _hash, out_shape);
- return result;
+ // Calculate individual reference using collapsed shapes
+ return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
}
TensorType _target{};
SimpleTensor<T> _reference{};
+ int32_t _hash{};
};
// This fixture will use the same shape for updates as indices.
@@ -162,9 +239,12 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
+ void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,
+ TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace, bool padding)
{
- ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
+ ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
+ indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, padding,
+ QuantizationInfo(), QuantizationInfo());
}
};