/*
 * Copyright (c) 2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H
#define ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H

#include "arm_compute/core/Utils.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "tests/Globals.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ScatterLayer.h"
#include "tests/SimpleTensor.h"

#include <cstdint>
#include <random>

namespace arm_compute
{
namespace test
{
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class ScatterGenericValidationFixture : public framework::Fixture
{
public:
    void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type,
               ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
    {
        _target    = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
        _reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f)
    {
        switch(tensor.data_type())
        {
            case DataType::F32:
            {
                std::uniform_real_distribution<float> distribution(lo, hi);
                library->fill(tensor, distribution, i);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Unsupported data type.");
            }
        }
    }

    // This is used to fill the indices tensor (S32 datatype).
    // The range is chosen to prevent the tensor from ONLY containing out-of-bounds values.
    template <typename U>
    void fill_indices(U &&tensor, int i, const TensorShape &shape)
    {
        // Calculate the max index the shape can contain. Use a small negative lower bound so that some out-of-bounds values are also generated.
        const int32_t max = std::max({shape[0], shape[1], shape[2]});
        library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
    }

    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape,
                              DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
    {
        // 1. Create relevant tensors using the ScatterInfo data structure.
        // ----------------------------------------------------------------
        // In order - src, updates, indices, output.
        TensorType src     = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
        TensorType updates = create_tensor<TensorType>(shape_b, data_type, 1, a_qinfo);
        TensorType indices = create_tensor<TensorType>(shape_c, DataType::S32, 1, QuantizationInfo());
        TensorType dst     = create_tensor<TensorType>(out_shape, data_type, 1, o_qinfo);

        FunctionType scatter;

        // Configure operator.
        // When scatter_info.zero_initialization is true, pass nullptr for src to the scatter function.
        if(info.zero_initialization)
        {
            scatter.configure(nullptr, &updates, &indices, &dst, info);
        }
        else
        {
            scatter.configure(&src, &updates, &indices, &dst, info);
        }

        // Assertions
        ARM_COMPUTE_ASSERT(src.info()->is_resizable());
        ARM_COMPUTE_ASSERT(updates.info()->is_resizable());
        ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
        ARM_COMPUTE_ASSERT(dst.info()->is_resizable());

        // Allocate tensors
        src.allocator()->allocate();
        updates.allocator()->allocate();
        indices.allocator()->allocate();
        dst.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

        // Fill src, updates and indices tensors.
        fill(AccessorType(src), 0);
        fill(AccessorType(updates), 1);
        fill_indices(AccessorType(indices), 2, out_shape);

        scatter.run();

        return dst;
    }

    SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape,
                                      DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
    {
        // Output quantization is not currently in use - the fixture should be extended to support this.
        ARM_COMPUTE_UNUSED(o_qinfo);

        TensorShape src_shape     = a_shape;
        TensorShape updates_shape = b_shape;
        TensorShape indices_shape = c_shape;

        // 1. Collapse the batch index into a single dim if necessary for the updates and indices tensors.
        if(c_shape.num_dimensions() >= 3)
        {
            indices_shape = indices_shape.collapsed_from(1);
            updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses the last 2 dims.
        }

        // 2. Collapse data dims into a single dim.
        //    Collapse all src dims into 2 dims: the first holding the data, the other being the index we iterate over.
        src_shape.collapse(updates_shape.num_dimensions() - 1);     // Collapse all data dims into a single dim.
        src_shape = src_shape.collapsed_from(1);                    // Collapse all index dims into a single dim.
        updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except the last dim, which is the batch dim).

        // Create reference tensors
        SimpleTensor<T>       src{ a_shape, data_type, 1, a_qinfo };
        SimpleTensor<T>       updates{ b_shape, data_type, 1, QuantizationInfo() };
        SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };

        // Fill reference
        fill(src, 0);
        fill(updates, 1);
        fill_indices(indices, 2, out_shape);

        // Calculate individual reference.
        return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
    }

    TensorType      _target{};
    SimpleTensor<T> _reference{};
};
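// Usage sketch (illustrative, not part of this header): a backend test file typically
// aliases ScatterValidationFixture (defined below) for concrete tensor, accessor and
// function types and registers it with the test framework. CLTensor, CLAccessor,
// CLScatter and FIXTURE_DATA_TEST_CASE are the usual CL-side test plumbing; the dataset
// name below is a placeholder and must supply the shapes, DataType, ScatterFunction and
// zero-init flag expected by setup().
//
//     template <typename T>
//     using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
//
//     FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
//                            combine(datasets::SmallScatterDataset(), // placeholder dataset name
//                                    framework::dataset::make("DataType", DataType::F32)))
//     {
//         // Compare the backend result against the reference implementation.
//         validate(CLAccessor(_target), _reference);
//     }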
// This fixture will use the same shape for updates as indices.
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
    void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
    {
        ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type,
                                                                                          ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
    }
};
} // namespace validation
} // namespace test
} // namespace arm_compute

#endif // ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H