diff options
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/ScatterLayerFixture.h | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 5cd9b8115c..af161ef98b 100644 --- a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -48,7 +48,7 @@ class ScatterGenericValidationFixture : public framework::Fixture { public: void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, - TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace, + TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace, bool padding, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo()) { // this is for improving randomness across tests @@ -57,7 +57,7 @@ public: + updates_shape[4] + updates_shape[5] + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3]; - _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo); + _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, padding, src_qinfo, o_qinfo); _reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo); } @@ -104,11 +104,11 @@ protected: { // Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension) const int32_t max = std::min({shape[0] , shape[1], shape[2]}) + 1; - library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max)); + library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(0), static_cast<int32_t>(max)); } TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, - const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, + const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, bool padding, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) { // 1. Create relevant tensors using ScatterInfo data structure. @@ -146,11 +146,14 @@ protected: ARM_COMPUTE_ASSERT(indices.info()->is_resizable()); ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); - add_padding_x({ &src, &updates, &indices}); - - if(!inplace) + if(padding) { - add_padding_x({ &dst }); + add_padding_x({ &src, &updates, &indices}); + + if(!inplace) + { + add_padding_x({ &dst }); + } } // Allocate tensors @@ -237,10 +240,10 @@ class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorTy { public: void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, - TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace) + TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace, bool padding) { ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, - indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, + indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, padding, QuantizationInfo(), QuantizationInfo()); } }; |