From 8609ca08556d4d928e290b963c71e731ac24bd52 Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Thu, 29 Feb 2024 17:00:07 +0000 Subject: Add skeleton for CLScatter op, reference and tests - Adds dataset for tests - Adds skeleton for function, operator, reference and tests Resolves: [COMPMID-6889] Signed-off-by: Mohammed Suhail Munshi Change-Id: I7e57e8b4577fef6aa7421e672894c249cad6c5fa Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11234 Comments-Addressed: Arm Jenkins Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Benchmark: Arm Jenkins --- Android.bp | 3 + arm_compute/core/Types.h | 2 +- arm_compute/function_info/ScatterInfo.h | 54 +++++++++ arm_compute/runtime/CL/CLFunctions.h | 3 +- arm_compute/runtime/CL/functions/CLScatter.h | 109 +++++++++++++++++ filelist.json | 9 ++ src/gpu/cl/kernels/ClScatterKernel.cpp | 78 ++++++++++++ src/gpu/cl/kernels/ClScatterKernel.h | 79 ++++++++++++ src/gpu/cl/operators/ClScatter.cpp | 93 ++++++++++++++ src/gpu/cl/operators/ClScatter.h | 96 +++++++++++++++ src/runtime/CL/functions/CLScatter.cpp | 86 +++++++++++++ tests/CMakeLists.txt | 3 +- tests/datasets/ScatterDataset.h | 127 ++++++++++++++++++++ tests/validation/CL/ScatterLayer.cpp | 57 +++++++++ tests/validation/fixtures/ScatterLayerFixture.h | 153 ++++++++++++++++++++++++ tests/validation/reference/ScatterLayer.cpp | 59 +++++++++ tests/validation/reference/ScatterLayer.h | 48 ++++++++ utils/TypePrinter.h | 72 +++++++++++ 18 files changed, 1128 insertions(+), 3 deletions(-) create mode 100644 arm_compute/function_info/ScatterInfo.h create mode 100644 arm_compute/runtime/CL/functions/CLScatter.h create mode 100644 src/gpu/cl/kernels/ClScatterKernel.cpp create mode 100644 src/gpu/cl/kernels/ClScatterKernel.h create mode 100644 src/gpu/cl/operators/ClScatter.cpp create mode 100644 src/gpu/cl/operators/ClScatter.h create mode 100644 src/runtime/CL/functions/CLScatter.cpp create mode 100644 tests/datasets/ScatterDataset.h create mode 100644 tests/validation/CL/ScatterLayer.cpp create mode 100644 tests/validation/fixtures/ScatterLayerFixture.h create mode 100644 tests/validation/reference/ScatterLayer.cpp create mode 100644 tests/validation/reference/ScatterLayer.h diff --git a/Android.bp b/Android.bp index d216c6785d..bb0486403b 100644 --- a/Android.bp +++ b/Android.bp @@ -707,6 +707,7 @@ cc_library_static { "src/gpu/cl/kernels/ClQuantizeKernel.cpp", "src/gpu/cl/kernels/ClReshapeKernel.cpp", "src/gpu/cl/kernels/ClScaleKernel.cpp", + "src/gpu/cl/kernels/ClScatterKernel.cpp", "src/gpu/cl/kernels/ClSoftmaxKernel.cpp", "src/gpu/cl/kernels/ClTransposeKernel.cpp", "src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp", @@ -758,6 +759,7 @@ cc_library_static { "src/gpu/cl/operators/ClQuantize.cpp", "src/gpu/cl/operators/ClReshape.cpp", "src/gpu/cl/operators/ClScale.cpp", + "src/gpu/cl/operators/ClScatter.cpp", "src/gpu/cl/operators/ClSoftmax.cpp", "src/gpu/cl/operators/ClSub.cpp", "src/gpu/cl/operators/ClTranspose.cpp", @@ -856,6 +858,7 @@ cc_library_static { "src/runtime/CL/functions/CLReshapeLayer.cpp", "src/runtime/CL/functions/CLReverse.cpp", "src/runtime/CL/functions/CLScale.cpp", + "src/runtime/CL/functions/CLScatter.cpp", "src/runtime/CL/functions/CLSelect.cpp", "src/runtime/CL/functions/CLSlice.cpp", "src/runtime/CL/functions/CLSoftmaxLayer.cpp", diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 6b51af17d4..f2f60c150e 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2023 Arm Limited. + * Copyright (c) 2016-2024 Arm Limited. * * SPDX-License-Identifier: MIT * diff --git a/arm_compute/function_info/ScatterInfo.h b/arm_compute/function_info/ScatterInfo.h new file mode 100644 index 0000000000..176a863ac5 --- /dev/null +++ b/arm_compute/function_info/ScatterInfo.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef ACL_ARM_COMPUTE_FUNCTION_INFO_SCATTERINFO_H +#define ACL_ARM_COMPUTE_FUNCTION_INFO_SCATTERINFO_H + +#include "arm_compute/core/Error.h" + +namespace arm_compute +{ +/** Scatter Function */ +enum class ScatterFunction +{ + Update = 0, + Add = 1, + Sub = 2, + Max = 3, + Min = 4 +}; +/** Scatter operator information */ +struct ScatterInfo +{ + ScatterInfo(ScatterFunction f, bool zero) : func(f), zero_initialization(zero) + { + ARM_COMPUTE_ERROR_ON_MSG(f != ScatterFunction::Add && zero, + "Zero initialisation is only supported with Add Scatter Function."); + } + ScatterFunction func; /**< Type of scatter function to use with scatter operator*/ + bool zero_initialization{false}; /**< Fill output tensors with 0. Only available with add scatter function. */ +}; +} // namespace arm_compute + +#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_SCATTERINFO_H diff --git a/arm_compute/runtime/CL/CLFunctions.h b/arm_compute/runtime/CL/CLFunctions.h index cf757239cb..a09ca551d2 100644 --- a/arm_compute/runtime/CL/CLFunctions.h +++ b/arm_compute/runtime/CL/CLFunctions.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2023 Arm Limited. + * Copyright (c) 2016-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -101,6 +101,7 @@ #include "arm_compute/runtime/CL/functions/CLROIAlignLayer.h" #include "arm_compute/runtime/CL/functions/CLROIPoolingLayer.h" #include "arm_compute/runtime/CL/functions/CLScale.h" +#include "arm_compute/runtime/CL/functions/CLScatter.h" #include "arm_compute/runtime/CL/functions/CLSelect.h" #include "arm_compute/runtime/CL/functions/CLSlice.h" #include "arm_compute/runtime/CL/functions/CLSoftmaxLayer.h" diff --git a/arm_compute/runtime/CL/functions/CLScatter.h b/arm_compute/runtime/CL/functions/CLScatter.h new file mode 100644 index 0000000000..1c90d208bd --- /dev/null +++ b/arm_compute/runtime/CL/functions/CLScatter.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H +#define ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H + +#include "arm_compute/core/Error.h" +#include "arm_compute/runtime/IFunction.h" + +#include + +namespace arm_compute +{ +class ICLTensor; +class ITensorInfo; +struct ScatterInfo; +class CLCompileContext; + +/** Function to compute ScatterND Layer */ +class CLScatter : public IFunction +{ +public: + /** Default Constructor */ + CLScatter(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLScatter(const CLScatter &) = delete; + /** Default move constructor */ + CLScatter(CLScatter &&); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLScatter &operator=(const CLScatter &) = delete; + /** Default move assignment operator */ + CLScatter &operator=(CLScatter &&); + /** Default destructor */ + ~CLScatter(); + /** Initialise the kernel's inputs and outputs + * + * Valid data layouts: + * - All + * + * + * @param[in] compile_context The compile context to be used. + * @param[in] src Source tensor. Values used to fill output. Can be nullptr when zero initialization is true. + * @param[in] updates Tensor containing values used to update output tensor. Data types supported: same as @p src + * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32 + * @param[out] output Destination tensor. Data types supported: same as @p src. + * @param[in] info Scatter info object. + */ + void configure(const CLCompileContext &compile_context, + const ICLTensor *src, + const ICLTensor *updates, + const ICLTensor *indices, + ICLTensor *output, + const ScatterInfo &info); + /** Initialise the kernel's inputs and output + * + * Similar to @ref CLScatter::configure() + */ + void configure(const ICLTensor *src, + const ICLTensor *updates, + const ICLTensor *indices, + ICLTensor *output, + const ScatterInfo &info); + /** Static function to check if given info will lead to a valid configuration of @ref CLScatter + * + * @param[in] src Source tensor. + * @param[in] updates Tensor containing values used for updating the output Tensor. Data types supported : same as @p src + * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32 + * @param[in] output Destination tensor. Data types supported: same as @p src. + * @param[in] info Scatter info containing type of scatter. + * + * @return a status + */ + static Status validate(const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + const ITensorInfo *output, + const ScatterInfo &info); + + // Inherited methods overridden: + void run() override; + +private: + struct Impl; + std::unique_ptr _impl; +}; +} // namespace arm_compute + +#endif // ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H diff --git a/filelist.json b/filelist.json index 9f0f302033..ab7f16bc90 100644 --- a/filelist.json +++ b/filelist.json @@ -770,6 +770,15 @@ ] } }, + "Scatter": { + "files": { + "common": [ + "src/gpu/cl/kernels/ClScatterKernel.cpp", + "src/gpu/cl/operators/ClScatter.cpp", + "src/runtime/CL/functions/CLScatter.cpp" + ] + } + }, "Select": { "files": { "common": [ diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp new file mode 100644 index 0000000000..720164366e --- /dev/null +++ b/src/gpu/cl/kernels/ClScatterKernel.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/gpu/cl/kernels/ClScatterKernel.h" + +#include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/ITensorPack.h" +#include "arm_compute/core/TensorInfo.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +ClScatterKernel::ClScatterKernel() +{ +} + +Status ClScatterKernel::validate(const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + const ITensorInfo *dst, + const ScatterInfo &info) +{ + ARM_COMPUTE_UNUSED(src); + ARM_COMPUTE_UNUSED(updates); + ARM_COMPUTE_UNUSED(indices); + ARM_COMPUTE_UNUSED(dst); + ARM_COMPUTE_UNUSED(info); + + return Status{}; +} +void ClScatterKernel::configure(const ClCompileContext &compile_context, + const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + ITensorInfo *dst, + const ScatterInfo &info) +{ + ARM_COMPUTE_UNUSED(compile_context); + ARM_COMPUTE_UNUSED(src); + ARM_COMPUTE_UNUSED(updates); + ARM_COMPUTE_UNUSED(indices); + ARM_COMPUTE_UNUSED(dst); + ARM_COMPUTE_UNUSED(info); +} + +void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) +{ + ARM_COMPUTE_UNUSED(tensors); + ARM_COMPUTE_UNUSED(window); + ARM_COMPUTE_UNUSED(queue); +} + +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/ClScatterKernel.h b/src/gpu/cl/kernels/ClScatterKernel.h new file mode 100644 index 0000000000..dda614ff3e --- /dev/null +++ b/src/gpu/cl/kernels/ClScatterKernel.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef ACL_SRC_GPU_CL_KERNELS_CLSCATTERKERNEL_H +#define ACL_SRC_GPU_CL_KERNELS_CLSCATTERKERNEL_H + +#include "arm_compute/function_info/ScatterInfo.h" + +#include "src/core/common/Macros.h" +#include "src/gpu/cl/ClCompileContext.h" +#include "src/gpu/cl/IClKernel.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +class ClScatterKernel : public IClKernel +{ +public: + ClScatterKernel(); + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClScatterKernel); + /** Initialise the kernel's input and output. + * + * @param[in] compile_context The compile context to be used. + * @param[in] src Input tensor info for the source matrix. + * @param[in] updates Input tensor info for the Update matrix. Data type supported: same as @p src + * @param[in] indices Input tensor info for the Indices matrix. Data type supported: U32. + * @param[out] dst Output tensor info. Data type supported: same as @p src + * @param[in] info Attributes for Scatter Kernel + */ + void configure(const ClCompileContext &compile_context, + const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + ITensorInfo *dst, + const ScatterInfo &info); + /** Static function to check if given info will lead to a valid configuration + * + * Similar to @ref ClScatterKernel::configure() + * + * @return a status + */ + static Status validate(const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + const ITensorInfo *dst, + const ScatterInfo &info); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; +}; +} // namespace kernels +} // namespace opencl +} // namespace arm_compute + +#endif // ACL_SRC_GPU_CL_KERNELS_CLSCATTERKERNEL_H diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp new file mode 100644 index 0000000000..74d747bc16 --- /dev/null +++ b/src/gpu/cl/operators/ClScatter.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/gpu/cl/operators/ClScatter.h" + +#include "arm_compute/core/Error.h" +#include "arm_compute/runtime/CL/CLScheduler.h" + +#include "src/common/utils/Log.h" +#include "src/gpu/cl/kernels/ClFillKernel.h" +#include "src/gpu/cl/kernels/ClScatterKernel.h" + +namespace arm_compute +{ +namespace opencl +{ +using namespace arm_compute::opencl::kernels; + +ClScatter::ClScatter() +{ +} + +Status ClScatter::validate(const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + const ITensorInfo *dst, + const ScatterInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32); + + return kernels::ClScatterKernel::validate(src, updates, indices, dst, info); +} + +void ClScatter::configure(const CLCompileContext &compile_context, + const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + ITensorInfo *dst, + const ScatterInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src, indices, dst); + ARM_COMPUTE_LOG_PARAMS(src, indices, dst, info); + ARM_COMPUTE_UNUSED(src); + ARM_COMPUTE_UNUSED(updates); + ARM_COMPUTE_UNUSED(indices); + ARM_COMPUTE_UNUSED(dst); + ARM_COMPUTE_UNUSED(info); + + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(validate(src, updates, indices, dst, info)); + _fill_zero = info.zero_initialization; + + // If necessary, create fill kernel to fill dst tensor. + if (_fill_zero) + { + _fill_kernel = std::make_unique(); + } + + // Configure ClScatterKernel + auto k = std::make_unique(); + k->set_target(CLScheduler::get().target()); + k->configure(compile_context, src, updates, indices, dst, info); + _scatter_kernel = std::move(k); +} + +void ClScatter::run(ITensorPack &tensors) +{ + ARM_COMPUTE_UNUSED(tensors); +} + +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/operators/ClScatter.h b/src/gpu/cl/operators/ClScatter.h new file mode 100644 index 0000000000..433f7ca3a4 --- /dev/null +++ b/src/gpu/cl/operators/ClScatter.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H +#define ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H + +#include "arm_compute/function_info/ScatterInfo.h" + +#include "src/gpu/cl/IClKernel.h" +#include "src/gpu/cl/IClOperator.h" + +#include + +namespace arm_compute +{ +namespace opencl +{ +// Forward declaration +class ClFillKernel; +class ClScatterKernel; + +/** Basic operator to execute Scatter on OpenCL. This operator calls the following OpenCL kernels: + * + * -# @ref kernels::ClScatterKernel + */ +class ClScatter : public IClOperator +{ +public: + /** Constructor */ + ClScatter(); + /** Default destructor */ + ~ClScatter() = default; + /** Initialise the kernel's inputs and output + * + * Valid data layouts: + * - All + * + * @note indices must always be U32 + * @note src, updates and dst tensors must be same datatype. + * + * @param[in] compile_context The compile context to be used. + * @param[in] src Source input tensor info. Can be nullptr when using "Add" Scatter Function with zero initialization. + * @param[in] updates Tensor info for tensor storing update values to use for scatter function. Data types supported: same as @p src. + * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: U32 only. + * @param[out] dst Output tensor to store the result of the Scatter Function. Data types supported: same as @p src and @p updates. + * @param[in] Scatter_info Contains Scatter operation information described in @ref ScatterInfo. + */ + void configure(const CLCompileContext &compile_context, + const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + ITensorInfo *dst, + const ScatterInfo &Scatter_info); + /** Static function to check if given info will lead to a valid configuration + * + * Similar to @ref ClScatter::configure() + * + * @return a status + */ + static Status validate(const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + const ITensorInfo *dst, + const ScatterInfo &Scatter_info); + // Inherited methods overridden: + void run(ITensorPack &tensors) override; + +private: + std::unique_ptr _scatter_kernel{nullptr}; + std::unique_ptr _fill_kernel{nullptr}; + bool _fill_zero{false}; +}; +} // namespace opencl +} // namespace arm_compute +#endif // ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H diff --git a/src/runtime/CL/functions/CLScatter.cpp b/src/runtime/CL/functions/CLScatter.cpp new file mode 100644 index 0000000000..e1de92968a --- /dev/null +++ b/src/runtime/CL/functions/CLScatter.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "arm_compute/runtime/CL/functions/CLScatter.h" + +#include "arm_compute/function_info/ScatterInfo.h" +#include "arm_compute/runtime/CL/CLTensor.h" + +#include "src/gpu/cl/operators/ClScatter.h" + +namespace arm_compute +{ +using OperatorType = opencl::ClScatter; + +struct CLScatter::Impl +{ + std::unique_ptr op{nullptr}; + ITensorPack run_pack{}; +}; + +CLScatter::CLScatter() : _impl(std::make_unique()) +{ +} + +CLScatter::~CLScatter() = default; + +void CLScatter::configure(const ICLTensor *src, + const ICLTensor *updates, + const ICLTensor *indices, + ICLTensor *output, + const ScatterInfo &info) +{ + ARM_COMPUTE_UNUSED(info); + configure(CLKernelLibrary::get().get_compile_context(), src, updates, indices, output, info); +} + +void CLScatter::configure(const CLCompileContext &compile_context, + const ICLTensor *src, + const ICLTensor *updates, + const ICLTensor *indices, + ICLTensor *output, + const ScatterInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src, indices, output); + + _impl->op = std::make_unique(); + _impl->op->configure(compile_context, src->info(), updates->info(), indices->info(), output->info(), info); + _impl->run_pack = {{ACL_SRC_0, src}, {ACL_SRC_1, updates}, {ACL_SRC_2, indices}, {ACL_DST, output}}; +} + +Status CLScatter::validate(const ITensorInfo *src, + const ITensorInfo *updates, + const ITensorInfo *indices, + const ITensorInfo *output, + const ScatterInfo &info) +{ + return OperatorType::validate(src, updates, indices, output, info); +} + +void CLScatter::run() +{ + _impl->op->run(_impl->run_pack); +} + +} // namespace arm_compute diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3f2223596f..20a010f38c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Arm Limited. +# Copyright (c) 2023-2024 Arm Limited. # # SPDX-License-Identifier: MIT # @@ -100,6 +100,7 @@ target_sources( validation/reference/Floor.cpp validation/reference/PriorBoxLayer.cpp validation/reference/Scale.cpp + validation/reference/ScatterLayer.cpp validation/reference/ReorgLayer.cpp validation/reference/Range.cpp validation/reference/ArithmeticDivision.cpp diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h new file mode 100644 index 0000000000..09f6338432 --- /dev/null +++ b/tests/datasets/ScatterDataset.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_TESTS_DATASETS_SCATTERDATASET_H +#define ACL_TESTS_DATASETS_SCATTERDATASET_H + +#include "arm_compute/core/TensorShape.h" +#include "utils/TypePrinter.h" + +namespace arm_compute +{ +namespace test +{ +namespace datasets +{ + +class ScatterDataset +{ +public: + using type = std::tuple; + + struct iterator + { + iterator(std::vector::const_iterator src_it, + std::vector::const_iterator updates_it, + std::vector::const_iterator indices_it, + std::vector::const_iterator dst_it) + : _src_it{ std::move(src_it) }, + _updates_it{ std::move(updates_it) }, + _indices_it{std::move(indices_it)}, + _dst_it{ std::move(dst_it) } + { + } + + std::string description() const + { + std::stringstream description; + description << "A=" << *_src_it << ":"; + description << "B=" << *_updates_it << ":"; + description << "C=" << *_indices_it << ":"; + description << "Out=" << *_dst_it << ":"; + return description.str(); + } + + ScatterDataset::type operator*() const + { + return std::make_tuple(*_src_it, *_updates_it, *_indices_it, *_dst_it); + } + + iterator &operator++() + { + ++_src_it; + ++_updates_it; + ++_indices_it; + ++_dst_it; + + return *this; + } + + private: + std::vector::const_iterator _src_it; + std::vector::const_iterator _updates_it; + std::vector::const_iterator _indices_it; + std::vector::const_iterator _dst_it; + }; + + iterator begin() const + { + return iterator(_src_shapes.begin(), _update_shapes.begin(), _indices_shapes.begin(), _dst_shapes.begin()); + } + + int size() const + { + return std::min(_src_shapes.size(), std::min(_indices_shapes.size(), std::min(_update_shapes.size(), _dst_shapes.size()))); + } + + void add_config(TensorShape a, TensorShape b, TensorShape c, TensorShape dst) + { + _src_shapes.emplace_back(std::move(a)); + _update_shapes.emplace_back(std::move(b)); + _indices_shapes.emplace_back(std::move(c)); + _dst_shapes.emplace_back(std::move(dst)); + } + +protected: + ScatterDataset() = default; + ScatterDataset(ScatterDataset &&) = default; + +private: + std::vector _src_shapes{}; + std::vector _update_shapes{}; + std::vector _indices_shapes{}; + std::vector _dst_shapes{}; +}; + +class SmallScatterDataset final : public ScatterDataset +{ +public: + SmallScatterDataset() + { + add_config(TensorShape(6U), TensorShape(6U), TensorShape(6U), TensorShape(6U)); + } +}; +} // namespace datasets +} // namespace test +} // namespace arm_compute +#endif // ACL_TESTS_DATASETS_SCATTERDATASET_H diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp new file mode 100644 index 0000000000..040ca41578 --- /dev/null +++ b/tests/validation/CL/ScatterLayer.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/runtime/CL/CLTensor.h" +#include "arm_compute/runtime/CL/functions/CLScatter.h" +#include "tests/validation/fixtures/ScatterLayerFixture.h" +#include "tests/CL/CLAccessor.h" +#include "tests/framework/Macros.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +namespace +{ +constexpr AbsoluteTolerance tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for 32-bit floating-point type */ +constexpr AbsoluteTolerance tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for 16-bit floating-point type */ +constexpr AbsoluteTolerance tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for 8-bit asymmetric type */ +constexpr AbsoluteTolerance tolerance_qasymm8_s(1); /**< Tolerance value for comparing reference's output against implementation's output for 8-bit signed asymmetric type */ +} // namespace + +template +using CLScatterLayerFixture = ScatterValidationFixture; + +TEST_SUITE(CL) +TEST_SUITE(ScatterLayer) +TEST_SUITE(Float) +TEST_SUITE(FP32) +TEST_SUITE_END() // FP32 +TEST_SUITE_END() // Float +TEST_SUITE_END() // ScatterLayer +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h new file mode 100644 index 0000000000..750e272388 --- /dev/null +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H +#define ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H + +#include "arm_compute/core/Utils.h" +#include "tests/Globals.h" +#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT +#include "tests/framework/Fixture.h" +#include "tests/validation/Validation.h" +#include "tests/validation/reference/ScatterLayer.h" +#include "tests/SimpleTensor.h" +#include + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +template +class ScatterGenericValidationFixture : public framework::Fixture +{ +public: + void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo()) + { + _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo); + _reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo); + } + +protected: + template + void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f) + { + switch(tensor.data_type()) + { + case DataType::F32: + { + std::uniform_real_distribution distribution(lo, hi); + library->fill(tensor, distribution, i); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported data type."); + } + } + } + + TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) + { + // 1. Create relevant tensors using ScatterInfo data structure. + // ---------------------------------------------------- + // In order - src, updates, indices, output. + TensorType src = create_tensor(shape_a, data_type, 1, a_qinfo); + TensorType updates = create_tensor(shape_b, data_type, 1, a_qinfo); + TensorType indices = create_tensor(shape_c, DataType::U32, 1, QuantizationInfo()); + TensorType dst = create_tensor(out_shape, data_type, 1, o_qinfo); + + FunctionType scatter; + + // Configure operator + scatter.configure(&src, &updates, &indices, &dst, info); + + // Assertions + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(updates.info()->is_resizable()); + ARM_COMPUTE_ASSERT(indices.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + + // Allocate tensors + src.allocator()->allocate(); + updates.allocator()->allocate(); + indices.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!updates.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!indices.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill update (a) and indices (b) tensors. + fill(AccessorType(src), 0); + fill(AccessorType(updates), 1); + fill(AccessorType(indices), 2); + + scatter.run(); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type, + ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) + { + // Output Quantization not currently in use - fixture should be extended to support this. + ARM_COMPUTE_UNUSED(o_qinfo); + + // Create reference tensors + SimpleTensor src{ a_shape, data_type, 1, a_qinfo }; + SimpleTensor updates{b_shape, data_type, 1, QuantizationInfo() }; + SimpleTensor indices{ c_shape, DataType::U32, 1, QuantizationInfo() }; + + // Fill reference + fill(src, 0); + fill(updates, 1); + fill(indices, 2); + + // Calculate individual reference. + auto result = reference::scatter_layer(src, updates, indices, out_shape, info); + + return result; + } + + TensorType _target{}; + SimpleTensor _reference{}; +}; + +// This fixture will use the same shape for updates as indices. +template +class ScatterValidationFixture : public ScatterGenericValidationFixture +{ +public: + void setup(TensorShape src_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init) + { + ScatterGenericValidationFixture::setup(src_shape, indices_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo()); + } +}; + +} // namespace validation +} // namespace test +} // namespace arm_compute +#endif // ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp new file mode 100644 index 0000000000..188cce100b --- /dev/null +++ b/tests/validation/reference/ScatterLayer.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "ScatterLayer.h" +#include "tests/validation/Helpers.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +namespace reference +{ + +template +SimpleTensor scatter_layer_internal(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info) +{ + ARM_COMPUTE_UNUSED(src); + ARM_COMPUTE_UNUSED(updates); + ARM_COMPUTE_UNUSED(indices); + ARM_COMPUTE_UNUSED(info); + // Unimplemented reference. + SimpleTensor dst{ out_shape, src.data_type(), 1 }; + return dst; +} + +template +SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info) +{ + return scatter_layer_internal(src, updates, indices, out_shape, info); +} + +template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); + +} // namespace reference +} // namespace validation +} // namespace test +} // namespace arm_compute diff --git a/tests/validation/reference/ScatterLayer.h b/tests/validation/reference/ScatterLayer.h new file mode 100644 index 0000000000..dc441a8894 --- /dev/null +++ b/tests/validation/reference/ScatterLayer.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_TESTS_VALIDATION_REFERENCE_SCATTERLAYER_H +#define ACL_TESTS_VALIDATION_REFERENCE_SCATTERLAYER_H + +#include "Utils.h" +#include "arm_compute/function_info/ScatterInfo.h" +#include "tests/SimpleTensor.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +namespace reference +{ +template +SimpleTensor scatter_layer_internal(const SimpleTensor &src, const SimpleTensor &update, const SimpleTensor &indices, const TensorShape &shape, const ScatterInfo &info); + +template +SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &update, const SimpleTensor &indices, const TensorShape &shape, const ScatterInfo &info); +} // namespace reference +} // namespace validation +} // namespace test +} // namespace arm_compute +#endif // ACL_TESTS_VALIDATION_REFERENCE_SCATTERLAYER_H diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index 41ac11801f..2d106d849a 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -49,6 +49,7 @@ #include "arm_compute/function_info/FullyConnectedLayerInfo.h" #include "arm_compute/function_info/GEMMInfo.h" #include "arm_compute/function_info/MatMulInfo.h" +#include "arm_compute/function_info/ScatterInfo.h" #include "arm_compute/runtime/CL/CLTunerTypes.h" #include "arm_compute/runtime/CL/CLTypes.h" #include "arm_compute/runtime/common/LSTMParams.h" @@ -3618,6 +3619,77 @@ inline std::string to_string(const arm_compute::CpuMatMulSettings &settings) return str.str(); } +/** Formatted output of the scatter function type. + * + * @param[out] os Output stream. + * @param[in] function arm_compute::ScatterFunction type to output. + * + * @return Modified output stream. + */ +inline ::std::ostream &operator<<(::std::ostream &os, const ScatterFunction &function) +{ + switch (function) + { + case ScatterFunction::Update: + os << "UPDATE"; + break; + case ScatterFunction::Add: + os << "ADD"; + break; + case ScatterFunction::Sub: + os << "SUB"; + break; + case ScatterFunction::Max: + os << "MAX"; + break; + case ScatterFunction::Min: + os << "MIN"; + break; + default: + ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); + } + return os; +} +/** Formatted output of the arm_compute::ScatterFunction type. + * + * @param[in] func arm_compute::ScatterFunction type to output. + * + * @return Formatted string. + */ +inline std::string to_string(const arm_compute::ScatterFunction &func) +{ + std::stringstream str; + str << func; + return str.str(); +} +/** Formatted output of the arm_compute::ScatterInfo type. + * + * @param[out] os Output stream. + * @param[in] info arm_compute::ScatterInfo type to output. + * + * @return Modified output stream. + */ +inline ::std::ostream &operator<<(::std::ostream &os, const arm_compute::ScatterInfo &info) +{ + os << "ScatterInfo=" + << "[" + << "Function=" << info.func << ", " + << "InitialiseZero=" << info.zero_initialization << "] "; + return os; +} +/** Formatted output of the arm_compute::ScatterInfo type. + * + * @param[in] info arm_compute::ScatterInfo type to output. + * + * @return Formatted string. + */ +inline std::string to_string(const arm_compute::ScatterInfo &info) +{ + std::stringstream str; + str << info; + return str.str(); +} + } // namespace arm_compute #endif // ACL_UTILS_TYPEPRINTER_H -- cgit v1.2.1