Diffstat (limited to 'tests/validation/dynamic_fusion')
-rw-r--r--  tests/validation/dynamic_fusion/Utils.h                       73
-rw-r--r--  tests/validation/dynamic_fusion/gpu/Integration.cpp          642
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Add.cpp               264
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Cast.cpp               97
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp             184
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp   474
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp      260
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp            335
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Mul.cpp               221
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp            219
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp           147
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Resize.cpp            359
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp           154
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp           219
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Sub.cpp               262
-rw-r--r--  tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp              154
16 files changed, 4064 insertions, 0 deletions
diff --git a/tests/validation/dynamic_fusion/Utils.h b/tests/validation/dynamic_fusion/Utils.h
new file mode 100644
index 0000000000..72e9ec5955
--- /dev/null
+++ b/tests/validation/dynamic_fusion/Utils.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef TESTS_VALIDATION_DYNAMIC_FUSION_UTILS
+#define TESTS_VALIDATION_DYNAMIC_FUSION_UTILS
+
+#include "tests/AssetsLibrary.h"
+#include "utils/Utils.h"
+
+#include <chrono>
+#include <limits>
+#include <type_traits>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace utils
+{
+/** A set of macros which measure the wall clock time elapsed between a TICK(clock_name)
+ * and the matching TOCK, recording it in microseconds into @p measurement_map under a key
+ * derived from @p clock_name. TOCK_AVG records the average over @p num_iterations instead.
+ */
+#define TICK(clock_name) \
+ auto clock_name##_tick = std::chrono::high_resolution_clock::now();
+#define TOCK(clock_name, measurement_map)                                \
+    auto clock_name##_tock = std::chrono::high_resolution_clock::now();  \
+    measurement_map["\"" #clock_name "\""] = std::chrono::duration_cast<std::chrono::microseconds>(clock_name##_tock - clock_name##_tick);
+#define TOCK_AVG(clock_name, measurement_map, num_iterations)            \
+    auto clock_name##_tock = std::chrono::high_resolution_clock::now();  \
+    measurement_map["\"" #clock_name "\""] = std::chrono::duration_cast<std::chrono::microseconds>((clock_name##_tock - clock_name##_tick) / (num_iterations));
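+/* A minimal usage sketch (hypothetical clock name `run` and a local map named
+ * `measurements`; both are illustrative, not part of this header):
+ *
+ *   std::map<std::string, std::chrono::microseconds> measurements;
+ *   TICK(run)
+ *   // ... workload under measurement ...
+ *   TOCK(run, measurements) // records measurements["\"run\""]
+ */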
+
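+/** Fill @p tensor with values drawn uniformly from [-1, 1] using @p seed, then
+ * fill its border with infinity so that out-of-bounds reads become detectable NaNs.
+ */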
+template <typename T, typename U>
+void fill(U &&tensor, int seed, AssetsLibrary *library)
+{
+ static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+ using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
+
+ DistributionType distribution{ T(-1.0f), T(1.0f) };
+ library->fill(tensor, distribution, seed);
+
+    // Fill the border with infinity so that out-of-bounds reads surface as NaN values in the output (i.e. inf * 0)
+ DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
+ library->fill_borders_with_garbage(tensor, distribution_inf, seed);
+}
+} // namespace utils
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
+
+#endif /* TESTS_VALIDATION_DYNAMIC_FUSION_UTILS */
diff --git a/tests/validation/dynamic_fusion/gpu/Integration.cpp b/tests/validation/dynamic_fusion/gpu/Integration.cpp
new file mode 100644
index 0000000000..453983c077
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/Integration.cpp
@@ -0,0 +1,642 @@
+/*
+ * Copyright (c) 2022-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/QuantizationInfo.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
+#include "arm_compute/dynamic_fusion/sketch/attributes/CastAttributes.h"
+#include "arm_compute/dynamic_fusion/sketch/attributes/Conv2dAttributes.h"
+#include "arm_compute/dynamic_fusion/sketch/attributes/DepthwiseConv2dAttributes.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuAdd.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuCast.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuDepthwiseConv2d.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSigmoid.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/dynamic_fusion/Utils.h"
+#include "tests/validation/reference/ActivationLayer.h"
+#include "tests/validation/reference/ConvolutionLayer.h"
+#include "tests/validation/reference/DepthConvertLayer.h"
+#include "tests/validation/reference/DepthwiseConvolutionLayer.h"
+#include "tests/validation/reference/ElementwiseOperations.h"
+#include "tests/validation/reference/Permute.h"
+#include "tests/validation/reference/PixelWiseMultiplication.h"
+#include "tests/validation/Validation.h"
+
+using namespace arm_compute::experimental::dynamic_fusion;
+using namespace arm_compute::test::validation::utils;
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+TEST_SUITE(CL)
+TEST_SUITE(INTEGRATION)
+TEST_SUITE(DYNAMIC_FUSION)
+
+TEST_CASE(Conv2d, framework::DatasetMode::ALL)
+{
+ /* Computation:
+ * out = conv2d1x1(direct_conv)(input, weights, bias)
+ */
+ CLScheduler::get().default_reinit();
+
+ const auto data_type = DataType::F32;
+ const auto data_layout = DataLayout::NHWC;
+ const auto t_input_shape = TensorShape(384, 12, 12);
+ const auto t_weight_shape = TensorShape(384, 1, 1, 16);
+ const auto t_dst_shape = TensorShape(16, 12, 12);
+
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
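+
+    // Dynamic fusion workflow: fuse operators into the sketch, mark the final
+    // tensors with GpuOutput, configure a ClWorkloadRuntime from the sketch and
+    // run it with the user tensors (inputs first, then outputs).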
+
+ // Fuse conv2d
+ Conv2dAttributes conv2d_attr{};
+ ITensorInfo *input_info = context.create_tensor_info(t_input_shape, 1, data_type, data_layout);
+ ITensorInfo *weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
+
+ ITensorInfo *conv_out_info = GpuConv2d::create_op(sketch, input_info, weight_info, nullptr, conv2d_attr);
+
+ ITensorInfo *dst_info = context.create_tensor_info();
+ GpuOutput::create_op(sketch, conv_out_info, dst_info);
+
+ // Configure runtime
+ ClWorkloadRuntime runtime;
+ runtime.configure(sketch);
+
+    // (Important) Allocate memory for the auxiliary tensors, if there are any
+    // Instead of using ACL-allocated memory, the user can choose to import memory into the tensors
+ for (auto &data : runtime.get_auxiliary_tensors())
+ {
+ CLTensor *tensor = std::get<0>(data);
+ TensorInfo info = std::get<1>(data);
+ AuxMemoryInfo aux_mem_req = std::get<2>(data);
+ tensor->allocator()->init(info, aux_mem_req.alignment);
+ tensor->allocator()->allocate(); // Use ACL allocated memory
+ // auto buf = cl::Buffer();
+ // tensor->allocator()->import_memory(buf); // Or, import external memory
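+        // (An imported buffer would need to be at least info.total_size() bytes
+        // and satisfy aux_mem_req.alignment.)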
+ }
+
+ // Construct user tensors
+ CLTensor t_input{};
+ CLTensor t_weight{};
+ CLTensor t_dst{};
+
+ // Initialize user tensors
+ t_input.allocator()->init(*input_info);
+ t_weight.allocator()->init(*weight_info);
+ t_dst.allocator()->init(*dst_info);
+
+ // Allocate and fill user tensors
+    // Instead of using the ACL allocator, the user can choose to import memory into the tensors
+ t_input.allocator()->allocate();
+ t_weight.allocator()->allocate();
+ t_dst.allocator()->allocate();
+ fill<float>(CLAccessor(t_input), 0, library.get());
+ fill<float>(CLAccessor(t_weight), 1, library.get());
+
+ // Run runtime
+ runtime.run({&t_input, &t_weight, &t_dst});
+
+ // Create reference
+ SimpleTensor<float> ref_t_input{t_input_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC};
+ SimpleTensor<float> ref_t_weight{t_weight_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC};
+ SimpleTensor<float> ref_t_bias_placeholder{t_dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC};
+
+ // Fill reference
+ fill<float>(ref_t_input, 0, library.get());
+ fill<float>(ref_t_weight, 1, library.get());
+
+ auto ref_t_input_nchw = reference::permute(ref_t_input, PermutationVector(1U, 2U, 0U));
+ auto ref_t_weight_nchw = reference::permute(ref_t_weight, PermutationVector(1U, 2U, 0U));
+ auto ref_t_bias_placeholder_nchw = reference::permute(ref_t_bias_placeholder, PermutationVector(1U, 2U, 0U));
+ auto t_dst_shape_nchw = t_dst_shape;
+ permute(t_dst_shape_nchw, PermutationVector(1U, 2U, 0U));
+
+ PadStrideInfo legacy_pad_stride(conv2d_attr.stride().x(), conv2d_attr.stride().y(), conv2d_attr.pad().left,
+ conv2d_attr.pad().right, conv2d_attr.pad().top, conv2d_attr.pad().bottom,
+ DimensionRoundingType{});
+ auto ref_t_dst_nchw = reference::convolution_layer(ref_t_input_nchw, ref_t_weight_nchw, ref_t_bias_placeholder_nchw,
+ t_dst_shape_nchw, legacy_pad_stride, conv2d_attr.dilation());
+ const auto ref_t_dst = reference::permute(ref_t_dst_nchw, PermutationVector(2U, 0U, 1U));
+
+ RelativeTolerance<float> tolerance_f32(
+ 0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
+ validate(CLAccessor(t_dst), ref_t_dst_nchw, tolerance_f32);
+}
+
+TEST_CASE(Add_Output_Add_Output, framework::DatasetMode::ALL)
+{
+ /* Computation:
+ * out_0 = in_0 + in_1
+ * out_1 = out_0 + in_2
+ */
+ CLScheduler::get().default_reinit();
+
+ const auto data_type = DataType::F32;
+ const auto t_input_shape = TensorShape(33, 3, 2);
+
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ ITensorInfo *in_0_info = context.create_tensor_info(t_input_shape, 1, data_type);
+ ITensorInfo *in_1_info = context.create_tensor_info(t_input_shape, 1, data_type);
+ ITensorInfo *in_2_info = context.create_tensor_info(t_input_shape, 1, data_type);
+
+ ITensorInfo *out_0_info = context.create_tensor_info();
+ ITensorInfo *out_1_info = context.create_tensor_info();
+
+ ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, in_0_info, in_1_info);
+ GpuOutput::create_op(sketch, ans_0_info, out_0_info);
+ ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, in_2_info);
+ GpuOutput::create_op(sketch, ans_1_info, out_1_info);
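+
+    // Note: ans_0_info is consumed twice, once by GpuOutput (out_0) and once as
+    // an input to the second addition, so the fused workload produces two outputs.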
+
+ // Configure runtime
+ ClWorkloadRuntime runtime;
+ runtime.configure(sketch);
+
+    // (Important) Allocate memory for the auxiliary tensors, if there are any
+    // Instead of using ACL-allocated memory, the user can choose to import memory into the tensors
+ for (auto &data : runtime.get_auxiliary_tensors())
+ {
+ CLTensor *tensor = std::get<0>(data);
+ TensorInfo info = std::get<1>(data);
+ AuxMemoryInfo aux_mem_req = std::get<2>(data);
+ tensor->allocator()->init(info, aux_mem_req.alignment);
+ tensor->allocator()->allocate(); // Use ACL allocated memory
+ // auto buf = cl::Buffer();
+ // tensor->allocator()->import_memory(buf); // Or, import external memory
+ }
+
+ // Construct user tensors
+ CLTensor t_in_0{};
+ CLTensor t_in_1{};
+ CLTensor t_in_2{};
+
+ CLTensor t_out_0{};
+ CLTensor t_out_1{};
+
+ // Initialize user tensors
+ t_in_0.allocator()->init(*in_0_info);
+ t_in_1.allocator()->init(*in_1_info);
+ t_in_2.allocator()->init(*in_2_info);
+
+ t_out_0.allocator()->init(*out_0_info);
+ t_out_1.allocator()->init(*out_1_info);
+
+ // Allocate and fill user tensors
+    // Instead of using the ACL allocator, the user can choose to import memory into the tensors
+ t_in_0.allocator()->allocate();
+ t_in_1.allocator()->allocate();
+ t_in_2.allocator()->allocate();
+
+ t_out_0.allocator()->allocate();
+ t_out_1.allocator()->allocate();
+
+ fill<float>(CLAccessor(t_in_0), 0, library.get());
+ fill<float>(CLAccessor(t_in_1), 1, library.get());
+ fill<float>(CLAccessor(t_in_2), 2, library.get());
+
+ // Run runtime
+ runtime.run({&t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1});
+
+ // Create reference
+ SimpleTensor<float> ref_t_in_0{t_input_shape, data_type, 1, QuantizationInfo()};
+ SimpleTensor<float> ref_t_in_1{t_input_shape, data_type, 1, QuantizationInfo()};
+ SimpleTensor<float> ref_t_in_2{t_input_shape, data_type, 1, QuantizationInfo()};
+
+ SimpleTensor<float> ref_t_out_0{t_input_shape, data_type, 1, QuantizationInfo()};
+ SimpleTensor<float> ref_t_out_1{t_input_shape, data_type, 1, QuantizationInfo()};
+
+ // Fill reference
+ fill<float>(ref_t_in_0, 0, library.get());
+ fill<float>(ref_t_in_1, 1, library.get());
+ fill<float>(ref_t_in_2, 2, library.get());
+
+ reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_in_0, ref_t_in_1, ref_t_out_0, ConvertPolicy::WRAP);
+ reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_out_1,
+ ConvertPolicy::WRAP);
+
+ RelativeTolerance<float> tolerance_f32(
+ 0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
+ validate(CLAccessor(t_out_0), ref_t_out_0, tolerance_f32);
+ validate(CLAccessor(t_out_1), ref_t_out_1, tolerance_f32);
+}
+TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL)
+{
+ /* Computation:
+ * out_0 = in_0 + in_1
+ * out_1 = float(int32_t(out_0 + in_2))
+ */
+ CLScheduler::get().default_reinit();
+
+ const auto data_type = DataType::F32;
+ const auto t_input_shape = TensorShape(3, 8, 5);
+
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ ITensorInfo *in_0_info = context.create_tensor_info(t_input_shape, 1, data_type);
+ ITensorInfo *in_1_info = context.create_tensor_info(t_input_shape, 1, data_type);
+ ITensorInfo *in_2_info = context.create_tensor_info(t_input_shape, 1, data_type);
+
+ ITensorInfo *out_0_info = context.create_tensor_info();
+ ITensorInfo *out_1_info = context.create_tensor_info();
+
+ CastAttributes cast_0_attr;
+    cast_0_attr.data_type(DataType::S32);
+
+ CastAttributes cast_1_attr;
+ cast_1_attr.data_type(DataType::F32);
+
+ ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, in_0_info, in_1_info);
+ GpuOutput::create_op(sketch, ans_0_info, out_0_info);
+ ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, in_2_info);
+ ITensorInfo *ans_2_info = GpuCast::create_op(sketch, ans_1_info, cast_0_attr);
+ ITensorInfo *ans_3_info = GpuCast::create_op(sketch, ans_2_info, cast_1_attr);
+ GpuOutput::create_op(sketch, ans_3_info, out_1_info);
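+
+    // The F32 -> S32 -> F32 round trip drops the fractional part, which is why
+    // the cast output is validated with an absolute tolerance of 1 below.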
+
+ // Configure runtime
+ ClWorkloadRuntime runtime;
+ runtime.configure(sketch);
+
+    // (Important) Allocate memory for the auxiliary tensors, if there are any
+    // Instead of using ACL-allocated memory, the user can choose to import memory into the tensors
+ for (auto &data : runtime.get_auxiliary_tensors())
+ {
+ CLTensor *tensor = std::get<0>(data);
+ TensorInfo info = std::get<1>(data);
+ AuxMemoryInfo aux_mem_req = std::get<2>(data);
+ tensor->allocator()->init(info, aux_mem_req.alignment);
+ tensor->allocator()->allocate(); // Use ACL allocated memory
+ // auto buf = cl::Buffer();
+ // tensor->allocator()->import_memory(buf); // Or, import external memory
+ }
+
+ // Construct user tensors
+ CLTensor t_in_0{};
+ CLTensor t_in_1{};
+ CLTensor t_in_2{};
+
+ CLTensor t_out_0{};
+ CLTensor t_out_1{};
+
+ // Initialize user tensors
+ t_in_0.allocator()->init(*in_0_info);
+ t_in_1.allocator()->init(*in_1_info);
+ t_in_2.allocator()->init(*in_2_info);
+
+ t_out_0.allocator()->init(*out_0_info);
+ t_out_1.allocator()->init(*out_1_info);
+
+ // Allocate and fill user tensors
+    // Instead of using the ACL allocator, the user can choose to import memory into the tensors
+ t_in_0.allocator()->allocate();
+ t_in_1.allocator()->allocate();
+ t_in_2.allocator()->allocate();
+
+ t_out_0.allocator()->allocate();
+ t_out_1.allocator()->allocate();
+
+ fill<float>(CLAccessor(t_in_0), 0, library.get());
+ fill<float>(CLAccessor(t_in_1), 1, library.get());
+ fill<float>(CLAccessor(t_in_2), 2, library.get());
+
+ // Run runtime
+ runtime.run({&t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1});
+
+ // Create reference
+ SimpleTensor<float> ref_t_in_0{t_input_shape, data_type, 1, QuantizationInfo()};
+ SimpleTensor<float> ref_t_in_1{t_input_shape, data_type, 1, QuantizationInfo()};
+ SimpleTensor<float> ref_t_in_2{t_input_shape, data_type, 1, QuantizationInfo()};
+
+ SimpleTensor<float> ref_t_out_0{t_input_shape, data_type, 1, QuantizationInfo()};
+ SimpleTensor<float> ref_t_ans_1{t_input_shape, data_type, 1, QuantizationInfo()};
+
+ // Fill reference
+ fill<float>(ref_t_in_0, 0, library.get());
+ fill<float>(ref_t_in_1, 1, library.get());
+ fill<float>(ref_t_in_2, 2, library.get());
+
+ reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_in_0, ref_t_in_1, ref_t_out_0, ConvertPolicy::WRAP);
+ reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_ans_1,
+ ConvertPolicy::WRAP);
+ const auto ref_t_ans_2 =
+ reference::depth_convert<float, int32_t>(ref_t_ans_1, DataType::S32, ConvertPolicy::SATURATE, 0);
+ const auto ref_t_out_1 =
+ reference::depth_convert<int32_t, float>(ref_t_ans_2, DataType::F32, ConvertPolicy::SATURATE, 0);
+
+ RelativeTolerance<float> tolerance_add_f32(0.001f);
+ AbsoluteTolerance<float> tolerance_cast_f32(1.0f);
+ validate(CLAccessor(t_out_0), ref_t_out_0, tolerance_add_f32);
+ validate(CLAccessor(t_out_1), ref_t_out_1, tolerance_cast_f32);
+}
+
+/// TODO: COMPMID-6593 : This integration test fails with the CKW backend.
+/// It was not enabled for CKW before, so the failure went unnoticed.
+TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::DISABLED)
+{
+ // (tensor0)
+ // |
+ // ======|============================================== Sketch 0
+ // | (tensor1) +---- (tensor2)
+ // | | | |
+ // +-- input -- weights -- biases --+ |
+ // | | |
+ // | Conv2d | |
+ // | | |
+ // +----------- output -------------+ |
+ // | |
+ // +-- input ---+ |
+ // | | |
+ // | Sigmoid | |
+ // | | |
+ // +-- output --+ |
+ // | |
+ // +-- input ---+ |
+ // | | |
+ // | Output | |
+ // | | |
+ // +-- output --+ |
+ // | |
+ // (tensor5) |
+ // | |
+ // +--------+ |
+ // ======|=============================|================ Sketch 1
+ // | (tensor3) (tensor4) |
+ // | | | |
+ // +-- input -- weights -- biases --+ |
+ // | | |
+ // | DepthwiseConv2d | |
+ // | | |
+ // +----------- output -------------+ |
+ // | |
+ // +--+ +----------------+
+ // | |
+ // +-- lhs -- rhs --+
+ // | |
+ // | Multiply |
+ // | |
+ // +---- output ----+
+ // |
+ // +-- input ---+
+ // | |
+ // | Output |
+ // | |
+ // +-- output --+
+ // |
+ // (tensor6)
+
+ TensorShape conv2d_src_shape(10, 20, 30);
+ TensorShape conv2d_wei_shape(10, 3, 3, 5);
+ TensorShape conv2d_bia_shape(5);
+ TensorShape conv2d_dst_shape(5, 18, 28);
+ TensorShape dwc_wei_shape(5, 3, 3);
+ TensorShape dwc_bia_shape(5);
+ TensorShape dwc_dst_shape(5, 16, 26);
+
+ // Initialize the context.
+ CLScheduler::get().default_reinit();
+
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context(&cl_compile_ctx);
+
+ auto tensor0_info = context.create_tensor_info(conv2d_src_shape, 1, DataType::F32, DataLayout::NHWC);
+
+    // Create the first sketch: conv2d + sigmoid + output.
+ GpuWorkloadSketch sketch0(&context);
+
+ Conv2dAttributes conv2d_attr;
+ auto tensor1_info = context.create_tensor_info(conv2d_wei_shape, 1, DataType::F32, DataLayout::NHWC);
+ auto tensor2_info = context.create_tensor_info(conv2d_bia_shape, 1, DataType::F32, DataLayout::NHWC);
+ ARM_COMPUTE_EXPECT(GpuConv2d::validate_op(sketch0, tensor0_info, tensor1_info, tensor2_info, conv2d_attr),
+ framework::LogLevel::ERRORS);
+ auto ans_info = GpuConv2d::create_op(sketch0, tensor0_info, tensor1_info, tensor2_info, conv2d_attr);
+
+ ARM_COMPUTE_EXPECT(GpuSigmoid::validate_op(sketch0, ans_info), framework::LogLevel::ERRORS);
+ ans_info = GpuSigmoid::create_op(sketch0, ans_info);
+
+ DepthwiseConv2dAttributes dwc_attr;
+ auto tensor3_info = context.create_tensor_info(dwc_wei_shape, 1, DataType::F32, DataLayout::NHWC);
+ auto tensor4_info = context.create_tensor_info(dwc_bia_shape, 1, DataType::F32, DataLayout::NHWC);
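+    // A second complex operator cannot be fused into the same sketch (see the
+    // Invalid_Fusion_Should_Fail suite below), so this validation is expected to fail.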
+ ARM_COMPUTE_EXPECT(!GpuDepthwiseConv2d::validate_op(sketch0, ans_info, tensor3_info, tensor4_info, dwc_attr),
+ framework::LogLevel::ERRORS);
+
+ auto tensor5_info = context.create_tensor_info();
+ ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch0, ans_info, tensor5_info), framework::LogLevel::ERRORS);
+ GpuOutput::create_op(sketch0, ans_info, tensor5_info);
+
+ // Create the first workload runtime.
+ ClWorkloadRuntime runtime0;
+ runtime0.configure(sketch0);
+
+    // Create the second sketch: dwc + mul + output.
+ GpuWorkloadSketch sketch1(&context);
+
+ ARM_COMPUTE_EXPECT(GpuDepthwiseConv2d::validate_op(sketch1, tensor5_info, tensor3_info, tensor4_info, dwc_attr),
+ framework::LogLevel::ERRORS);
+ ans_info = GpuDepthwiseConv2d::create_op(sketch1, tensor5_info, tensor3_info, tensor4_info, dwc_attr);
+
+ ARM_COMPUTE_EXPECT(GpuMul::validate_op(sketch1, ans_info, tensor2_info), framework::LogLevel::ERRORS);
+ ans_info = GpuMul::create_op(sketch1, ans_info, tensor2_info);
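+    // Note: tensor2_info (the conv2d bias) is deliberately reused as the rhs of the multiplication.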
+
+ auto tensor6_info = context.create_tensor_info();
+ ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch1, ans_info, tensor6_info), framework::LogLevel::ERRORS);
+ GpuOutput::create_op(sketch1, ans_info, tensor6_info);
+
+ // Create the second workload runtime.
+ ClWorkloadRuntime runtime1;
+ runtime1.configure(sketch1);
+
+ // Create the user tensors.
+ CLTensor tensor0;
+ CLTensor tensor1;
+ CLTensor tensor2;
+ CLTensor tensor3;
+ CLTensor tensor4;
+ CLTensor tensor5;
+ CLTensor tensor6;
+
+ tensor0.allocator()->init(*tensor0_info);
+ tensor1.allocator()->init(*tensor1_info);
+ tensor2.allocator()->init(*tensor2_info);
+ tensor3.allocator()->init(*tensor3_info);
+ tensor4.allocator()->init(*tensor4_info);
+ tensor5.allocator()->init(*tensor5_info);
+ tensor6.allocator()->init(*tensor6_info);
+
+ tensor0.allocator()->allocate();
+ tensor1.allocator()->allocate();
+ tensor2.allocator()->allocate();
+ tensor3.allocator()->allocate();
+ tensor4.allocator()->allocate();
+ tensor5.allocator()->allocate();
+ tensor6.allocator()->allocate();
+
+ // Allocate the auxiliary tensors.
+ for (auto &data : runtime0.get_auxiliary_tensors())
+ {
+ auto tensor = std::get<0>(data);
+ auto &tensor_info = std::get<1>(data);
+ auto mem_req = std::get<2>(data);
+
+ tensor->allocator()->init(tensor_info, mem_req.alignment);
+ tensor->allocator()->allocate();
+ }
+
+ for (auto &data : runtime1.get_auxiliary_tensors())
+ {
+ auto tensor = std::get<0>(data);
+ auto &tensor_info = std::get<1>(data);
+ auto mem_req = std::get<2>(data);
+
+ tensor->allocator()->init(tensor_info, mem_req.alignment);
+ tensor->allocator()->allocate();
+ }
+
+ // Fill the input tensors with random data.
+ fill<float>(CLAccessor(tensor0), 0, library.get());
+ fill<float>(CLAccessor(tensor1), 1, library.get());
+ fill<float>(CLAccessor(tensor2), 2, library.get());
+ fill<float>(CLAccessor(tensor3), 3, library.get());
+ fill<float>(CLAccessor(tensor4), 4, library.get());
+
+ // Run each runtime.
+ runtime0.run({&tensor0, &tensor1, &tensor2, &tensor5});
+ runtime1.run({&tensor5, &tensor3, &tensor4, &tensor2, &tensor6});
+
+ // Compute the reference result.
+ SimpleTensor<float> ref_conv2d_src(conv2d_src_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+ SimpleTensor<float> ref_conv2d_wei(conv2d_wei_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+ SimpleTensor<float> ref_conv2d_bia(conv2d_bia_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+ SimpleTensor<float> ref_dwc_wei(dwc_wei_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+ SimpleTensor<float> ref_dwc_bia(dwc_bia_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+
+ fill<float>(ref_conv2d_src, 0, library.get());
+ fill<float>(ref_conv2d_wei, 1, library.get());
+ fill<float>(ref_conv2d_bia, 2, library.get());
+ fill<float>(ref_dwc_wei, 3, library.get());
+ fill<float>(ref_dwc_bia, 4, library.get());
+
+ PermutationVector nhwc_to_nchw(1, 2, 0);
+
+ auto conv2d_dst_shape_nchw = conv2d_dst_shape;
+ permute(conv2d_dst_shape_nchw, nhwc_to_nchw);
+ const auto ref_conv2d_src_nchw = reference::permute(ref_conv2d_src, nhwc_to_nchw);
+ const auto ref_conv2d_wei_nchw = reference::permute(ref_conv2d_wei, nhwc_to_nchw);
+ const auto ref_conv2d_bia_nchw = reference::permute(ref_conv2d_bia, nhwc_to_nchw);
+ const auto ref_conv2d_dst_nchw = reference::convolution_layer(
+ ref_conv2d_src_nchw, ref_conv2d_wei_nchw, ref_conv2d_bia_nchw, conv2d_dst_shape_nchw, PadStrideInfo());
+
+ const auto ref_sigmoid_dst_nchw = reference::activation_layer(
+ ref_conv2d_dst_nchw, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+
+ auto dwc_dst_shape_nchw = dwc_dst_shape;
+ permute(dwc_dst_shape_nchw, nhwc_to_nchw);
+ const auto ref_dwc_wei_nchw = reference::permute(ref_dwc_wei, nhwc_to_nchw);
+ const auto ref_dwc_bia_nchw = reference::permute(ref_dwc_bia, nhwc_to_nchw);
+ const auto ref_dwc_dst_nchw = reference::depthwise_convolution(
+ ref_sigmoid_dst_nchw, ref_dwc_wei_nchw, ref_dwc_bia_nchw, dwc_dst_shape_nchw, PadStrideInfo(), 1);
+
+ const auto ref_mul_dst_nchw = reference::pixel_wise_multiplication<float, float, float>(
+ ref_dwc_dst_nchw, ref_conv2d_bia_nchw, 1.0, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP,
+ DataType::F32);
+
+ constexpr RelativeTolerance<float> tolerance(0.001f);
+ validate(CLAccessor(tensor6), ref_mul_dst_nchw, tolerance);
+}
+
+TEST_SUITE(Invalid_Fusion_Should_Fail)
+TEST_CASE(Multiple_Complex_Ops_0, framework::DatasetMode::ALL)
+{
+ /* Computation:
+ * out = conv2d(conv2d(l0_input, l0_weight), l1_weight)
+ */
+ CLScheduler::get().default_reinit();
+
+ const auto data_type = DataType::F32;
+ const auto data_layout = DataLayout::NHWC;
+ const auto t_input_shape = TensorShape(384, 12, 12);
+ const auto t_weight_shape = TensorShape(384, 1, 1, 16);
+ auto t_input_info = TensorInfo(t_input_shape, 1, data_type, data_layout);
+ auto t_weight_info = TensorInfo(t_weight_shape, 1, data_type, data_layout);
+ auto t_dst_info = TensorInfo();
+
+ Conv2dAttributes conv2d_attr{};
+
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ // Create tensor infos
+ ITensorInfo *input_info = context.create_tensor_info(t_input_shape, 1, data_type, data_layout);
+ ITensorInfo *weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
+ ITensorInfo *dst_info;
+
+ // Fuse conv2d into the workload
+ {
+ // Validate operator
+ const Status success = GpuConv2d::validate_op(sketch, input_info, weight_info, nullptr, conv2d_attr);
+ ARM_COMPUTE_EXPECT(bool(success), framework::LogLevel::ERRORS);
+
+ dst_info = GpuConv2d::create_op(sketch, input_info, weight_info, nullptr, conv2d_attr);
+ }
+
+ // Create tensor infos
+ ITensorInfo *weight_info_2 = context.create_tensor_info(t_weight_info);
+
+ // Fuse conv2d into the workload
+ {
+ // Validate operator, should fail
+ const Status success = GpuConv2d::validate_op(sketch, dst_info, weight_info_2, nullptr, conv2d_attr);
+ const auto expected_error_str = "Operator fusion test failed. This operator cannot be fused into the workload";
+
+ ARM_COMPUTE_EXPECT(!bool(success), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT((success.error_description().find(expected_error_str) != std::string::npos),
+ framework::LogLevel::ERRORS);
+ }
+}
+TEST_SUITE_END() // Invalid_Fusion_Should_Fail
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // INTEGRATION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp
new file mode 100644
index 0000000000..9bfdc961fe
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp
@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) 2022-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuAdd.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/DynamicFusionDataset.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+/* Synced with tests/validation/CL/ArithmeticAddition.cpp from the standard interface.
+ *
+ * Difference | Why the difference
+ * No quantized tests | Not supported yet
+ * No in place tests | Not supported yet
+ * No activation tests | Not needed in dynamic fusion interface
+ *
+ */
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(ADD)
+
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
+ framework::dataset::make("LhsInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Invalid data type combination
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), // S16 is valid data type for Add
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32), // S32 is valid data type for Add
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
+ TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for lhs
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), // Unsupported data type QASYMM8
+                                                          TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8_SIGNED), // Unsupported data type QASYMM8_SIGNED
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(15U, 23U, 3U), 1, DataType::F32), // Broadcast Y dimension is not allowed
+ TensorInfo(TensorShape( 3U, 8U, 9U), 1, DataType::S16), // Broadcast Z dimension is not allowed
+ TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32), // Batching is allowed
+ }),
+ framework::dataset::make("RhsInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), // Unsupported data type QASYMM8
+                                                         TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8_SIGNED), // Unsupported data type QASYMM8_SIGNED
+ TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for rhs
+ TensorInfo(TensorShape(15U, 1U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape( 3U, 8U, 1U), 1, DataType::S16),
+ TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32),
+ })),
+ framework::dataset::make("Expected", { true, false, true, true, false, true, false, false, true, false, false, true})),
+ input1_info, input2_info, expected)
+{
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ // Validate Elementwise Add
+ auto lhs_info = context.create_tensor_info(input1_info);
+ auto rhs_info = context.create_tensor_info(input2_info);
+
+ bool res = bool(GpuAdd::validate_op(sketch, lhs_info, rhs_info));
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+constexpr AbsoluteTolerance<float> tolerance_f(
+ 0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 and DataType::F16 */
+constexpr float tolerance_num = 0.0001f; /**< Tolerance number */
+
+template <typename T>
+using DynamicFusionCLAddFixture =
+ DynamicFusionGpuElementwiseBinaryOneOpValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
+
+template <typename T>
+using DynamicFusionCLAddBroadcastFixture =
+ DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
+
+template <typename T>
+using DynamicFusionCLAddTwoOpsFixture =
+ DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
+
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionCLAddFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f);
+}
+FIXTURE_DATA_TEST_CASE(RunLargeOneOp,
+ DynamicFusionCLAddFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::LargeShapes()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f);
+}
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
+ DynamicFusionCLAddBroadcastFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::TemporaryLimitedSmallShapesBroadcast()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp,
+ DynamicFusionCLAddBroadcastFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::TemporaryLimitedLargeShapesBroadcast()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f);
+}
+FIXTURE_DATA_TEST_CASE(
+ RunSmallTwoOps,
+ DynamicFusionCLAddTwoOpsFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})),
+ framework::dataset::make("FuseTwoOps", {true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f);
+}
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionCLAddFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::F16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f, tolerance_num);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
+ DynamicFusionCLAddBroadcastFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::TemporaryLimitedSmallShapesBroadcast()),
+ framework::dataset::make("DataType", {DataType::F16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f, tolerance_num);
+}
+
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(S32)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionCLAddFixture<int32_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::S32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // S32
+
+TEST_SUITE(S16)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionCLAddFixture<int16_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::S16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionCLAddFixture<int16_t>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::LargeShapes()),
+ framework::dataset::make("DataType", {DataType::S16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // S16
+
+TEST_SUITE(U8)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionCLAddFixture<uint8_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::U8})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // U8
+
+TEST_SUITE_END() // ADD
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Cast.cpp b/tests/validation/dynamic_fusion/gpu/cl/Cast.cpp
new file mode 100644
index 0000000000..4ef359e74d
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Cast.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2022-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/Types.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuCast.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/ConvertPolicyDataset.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+// Tolerance
+constexpr AbsoluteTolerance<float> zero_tolerance(0);
+
+/** Input data sets **/
+
+// F16
+const auto CastF16toF32Dataset = combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F32));
+
+// F32
+const auto CastF32toF16Dataset = combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F16));
+
+class DFConvertPolicies final : public framework::dataset::ContainerDataset<std::vector<ConvertPolicy>>
+{
+public:
+ DFConvertPolicies()
+ : ContainerDataset("ConvertPolicy",
+ {
+ ConvertPolicy::WRAP
+ })
+ {
+ }
+};
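+
+// A WRAP-only stand-in for the standard ConvertPolicies dataset: the dynamic
+// fusion cast tests exercise ConvertPolicy::WRAP only.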
+} // namespace
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(CAST)
+
+template <typename T>
+using DynamicFusionCLCastToF16Fixture = DynamicFusionCastValidationFixture<CLTensor, CLAccessor, GpuCast, T, half>;
+template <typename T>
+using DynamicFusionCLCastToF32Fixture = DynamicFusionCastValidationFixture<CLTensor, CLAccessor, GpuCast, T, float>;
+
+#define CAST_SUITE(NAME, idt, odt, type, dataset, tolerance) \
+ TEST_SUITE(NAME) \
+ FIXTURE_DATA_TEST_CASE(RunSmall, type, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), dataset), \
+ DFConvertPolicies())) \
+ { \
+ validate(CLAccessor(_target), _reference, tolerance); \
+ } \
+ TEST_SUITE_END()
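+
+// Each CAST_SUITE invocation below instantiates a PRECOMMIT RunSmall test for one
+// (source, destination) type pair with the given fixture, dataset and tolerance.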
+
+// F16
+CAST_SUITE(F16_to_F32, DataType::F16, DataType::F32, DynamicFusionCLCastToF32Fixture<half>, CastF16toF32Dataset, zero_tolerance)
+
+// F32
+CAST_SUITE(F32_to_F16, DataType::F32, DataType::F16, DynamicFusionCLCastToF16Fixture<float>, CastF32toF16Dataset, zero_tolerance)
+
+TEST_SUITE_END() // CAST
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp
new file mode 100644
index 0000000000..cef8b87c3f
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) 2022-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/Types.h"
+#include "arm_compute/dynamic_fusion/sketch/attributes/ClampAttributes.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuClamp.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+constexpr float epsilon = 1e-6f;
+constexpr AbsoluteTolerance<float> tolerance(epsilon);
+} // namespace
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(CLAMP)
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Minimum value larger than maximum value
+ }),
+ framework::dataset::make("MinVal", { 0.2f,
+ 1.5f,
+ 9.0f,
+ })),
+ framework::dataset::make("MaxVal", { 0.5f,
+ 2.0f,
+ 1.0f,
+ })),
+ framework::dataset::make("Expected", { true, true, false })),
+ input_info, min_val, max_val, expected)
+{
+ // Create a new workload sketch
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ // Fuse Clamp
+ const ITensorInfo* src_info = context.create_tensor_info(input_info);
+
+ ClampAttributes attributes {};
+ attributes.min_val(min_val)
+ .max_val(max_val);
+
+ const bool res = static_cast<bool>(GpuClamp::validate_op(sketch, src_info, attributes));
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using DynamicFusionClampOpFixture = DynamicFusionClampValidationFixture<CLTensor, CLAccessor, GpuClamp, T>;
+
+TEST_SUITE(Float)
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionClampOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make(
+ "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.6f)})),
+ framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
+ DynamicFusionClampOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::Small5dShapes(),
+ framework::dataset::make(
+ "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.6f)})),
+ framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ ARM_COMPUTE_TEST_INFO("Currently 5D+ tensors are unsupported for this operation.");
+ framework::ARM_COMPUTE_PRINT_INFO();
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
+ DynamicFusionClampOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make(
+ "ClampAttributes", {ClampAttributes().min_val(0.2f).max_val(0.4f)})),
+ framework::dataset::make("Fuse", {true})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance);
+}
+
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionClampOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make(
+ "ClampAttributes", {ClampAttributes().min_val(0.3f).max_val(0.7f)})),
+ framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
+ DynamicFusionClampOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::Small5dShapes(),
+ framework::dataset::make(
+ "ClampAttributes", {ClampAttributes().min_val(0.3f).max_val(0.7f)})),
+ framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ ARM_COMPUTE_TEST_INFO("Currently 5D+ tensors are unsupported for this operation.");
+ framework::ARM_COMPUTE_PRINT_INFO();
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
+ DynamicFusionClampOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make(
+ "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.9f)})),
+ framework::dataset::make("Fuse", {true})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance);
+}
+
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // Float
+
+TEST_SUITE_END() // CLAMP
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
new file mode 100644
index 0000000000..2f8c639cea
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
@@ -0,0 +1,474 @@
+/*
+ * Copyright (c) 2022-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuDepthwiseConv2d.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/DepthwiseConvolutionLayerDataset.h"
+#include "tests/datasets/DilatedDepthwiseConvolutionLayerDataset.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+const auto depth_multipliers = framework::dataset::make("DepthMultiplier", {1U, 4U});
+const auto large_depth_multipliers = framework::dataset::make("DepthMultiplier", {1, 2, 5, 8});
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(DEPTHWISE_CONV2D)
+
+RelativeTolerance<float> tolerance_f32(
+ 0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+RelativeTolerance<half_float::half> tolerance_f16(half_float::half(
+ 0.1)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr float tolerance_num = 0.02f; /**< Tolerance number */
+
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip( // Explanations of failing tests
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Mismatching data type input/weights
+ TensorInfo(TensorShape(3U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Mismatching input feature maps
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Mismatching depth multiplier
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Invalid biases size
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Invalid biases dimensions
+ TensorInfo(TensorShape(8U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // dilation < 1
+ TensorInfo(TensorShape(8U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::QASYMM8, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::QASYMM8_SIGNED, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::QSYMM16, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::QSYMM8, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::QSYMM8_PER_CHANNEL, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::QASYMM16, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::U8, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::S8, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::U16, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::S16, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::U32, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(8U, 32U, 13U), 1, DataType::S32, DataLayout::NHWC), // Unsupported data type
+ TensorInfo(TensorShape(32U, 13U, 8U), 1, DataType::F32, DataLayout::NCHW), // Unsupported data layout
+ TensorInfo(TensorShape(8U, 32U, 13U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(8U, 32U, 13U, 4U), 1, DataType::F32, DataLayout::NHWC), // weight dimension > 3
+ TensorInfo(TensorShape(8U, 32U, 13U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(8U, 32U, 13U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(8U, 32U, 13U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ }),
+ framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(2U, 3U, 3U, 2U), 1, DataType::F16, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 3U, 3U, 2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 3U, 3U, 2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 3U, 3U, 2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 3U, 3U, 2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(16U, 3U, 3U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(16U, 3U, 3U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::QASYMM8, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::QASYMM8_SIGNED, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::QSYMM16, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::QSYMM8, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::QSYMM8_PER_CHANNEL, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::QASYMM16, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::U8, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::S8, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::U16, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::S16, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::U32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(3U, 3U, 24U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U, 5U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 3U, 3U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U, 4U, 3U), 1, DataType::F32, DataLayout::NHWC),
+ })),
+ framework::dataset::make("BiasesInfo", { TensorInfo(TensorShape(2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(16U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(16U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::S32, DataLayout::NCHW),
+ TensorInfo(TensorShape(24U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(24U), 1, DataType::F32, DataLayout::NHWC),
+ })),
+ framework::dataset::make("Padding", { Padding2D(0, 0, 0, 0),
+ Padding2D(0, 0, 0, 0),
+ Padding2D(0, 0, 0, 0),
+ Padding2D(0, 0, 0, 0),
+ Padding2D(0, 0, 0, 0),
+ Padding2D(0, 0, 0, 0),
+ Padding2D(0, 0, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(1, 1, 0, 0),
+ Padding2D(2, 1, 2, 1),
+ Padding2D(2, 1, 2, 1),
+ Padding2D(2, 1, 2, 1),
+ })),
+ framework::dataset::make("Stride", { Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(1, 1),
+ Size2D(2, 3),
+ Size2D(2, 3),
+ })),
+ framework::dataset::make("DepthMultiplier", { 1,
+ 1,
+ 3,
+ 1,
+ 1,
+ 2,
+ 2,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ })),
+ framework::dataset::make("Dilation", { Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(0U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(1U, 1U),
+ Size2D(2U, 3U),
+ })),
+ framework::dataset::make("Expected", { false, false, false, false, false, false, true, false,
+ false, false, false, false, false, false, false, false, false, false,
+ false, false, true, false, true, true, true })),
+ input_info, weights_info, biases_info, padding, stride, depth_multiplier, dilation, expected)
+{
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+    GpuWorkloadContext context{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ const ITensorInfo* sketch_input_info = context.create_tensor_info(input_info);
+ const ITensorInfo* sketch_weights_info = context.create_tensor_info(weights_info);
+ const ITensorInfo* sketch_biases_info = context.create_tensor_info(biases_info);
+
+ DepthwiseConv2dAttributes attributes {};
+ attributes.pad(padding)
+ .stride(stride)
+ .dilation(dilation)
+ .depth_multiplier(depth_multiplier);
+
+ const Status status = GpuDepthwiseConv2d::validate_op(sketch, sketch_input_info, sketch_weights_info, sketch_biases_info, attributes);
+ const bool res = bool(status);
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using DynamicFusionGpuDepthwiseConv2dFixture =
+ DynamicFusionGpuDepthwiseConv2dValidationFixture<CLTensor, CLAccessor, GpuDepthwiseConv2d, T>;
+
+TEST_SUITE(Float)
+TEST_SUITE(FP16)
+TEST_SUITE(W3x3)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuDepthwiseConv2dFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuDepthwiseConv2dFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
+ large_depth_multipliers),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+TEST_SUITE(Dilation)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuDepthwiseConv2dFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
+ depth_multipliers),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuDepthwiseConv2dFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
+ large_depth_multipliers),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+TEST_SUITE_END() // Dilation
+TEST_SUITE_END() // W3x3
+
+TEST_SUITE(Generic)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuDepthwiseConv2dFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(), depth_multipliers),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuDepthwiseConv2dFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(),
+ large_depth_multipliers),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+
+TEST_SUITE(Dilation)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuDepthwiseConv2dFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuDepthwiseConv2dFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
+ large_depth_multipliers),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+TEST_SUITE_END() // Dilation
+TEST_SUITE_END() // Generic
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(FP32)
+TEST_SUITE(W3x3)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
+ large_depth_multipliers),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+TEST_SUITE(Dilation)
+
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
+ depth_multipliers),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
+ large_depth_multipliers),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+TEST_SUITE_END() // Dilation
+TEST_SUITE_END() // W3x3
+
+TEST_SUITE(Generic)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(), depth_multipliers),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(),
+ large_depth_multipliers),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeKernelSize,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::LargeKernelSizeDepthwiseConvolutionLayerNHWCDataset(),
+ framework::dataset::make("DepthMultiplier", {1})),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+TEST_SUITE(Dilation)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuDepthwiseConv2dFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+                       combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
+ large_depth_multipliers),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+TEST_SUITE_END() // Dilation
+TEST_SUITE_END() // Generic
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // DEPTHWISE_CONV2D
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
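For orientation, the pieces exercised by the Validate case above compose into a complete fused workload roughly as sketched below. This is an editorial illustration only: it assumes the GpuDepthwiseConv2d::create_op(), GpuOutput::create_op() and ClWorkloadRuntime interfaces used by the Integration tests in this series, and it elides tensor allocation, filling and execution.

// Editorial sketch: validate, record and compile a fused depthwise convolution
// (shapes taken from one of the Expected == true rows of the Validate case:
// 8 input channels, 3x3 weights, depth multiplier 2).
auto               cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
GpuWorkloadContext context{ &cl_compile_ctx };
GpuWorkloadSketch  sketch{ &context };

ITensorInfo *input_info   = context.create_tensor_info(TensorInfo(TensorShape(8U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC));
ITensorInfo *weights_info = context.create_tensor_info(TensorInfo(TensorShape(16U, 3U, 3U), 1, DataType::F32, DataLayout::NHWC));
ITensorInfo *biases_info  = context.create_tensor_info(TensorInfo(TensorShape(16U), 1, DataType::F32, DataLayout::NHWC));

DepthwiseConv2dAttributes attributes{};
attributes.pad(Padding2D()).stride(Size2D(1, 1)).dilation(Size2D(1U, 1U)).depth_multiplier(2);

if(bool(GpuDepthwiseConv2d::validate_op(sketch, input_info, weights_info, biases_info, attributes)))
{
    // Record the operator and bind its result to an output tensor info
    ITensorInfo *ans_info = GpuDepthwiseConv2d::create_op(sketch, input_info, weights_info, biases_info, attributes);
    GpuOutput::create_op(sketch, ans_info, context.create_tensor_info());

    // Compile the sketch; running it would additionally need allocated CLTensors,
    // e.g. runtime.run({ &t_input, &t_weights, &t_biases, &t_dst })
    ClWorkloadRuntime runtime;
    runtime.configure(sketch);
}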
diff --git a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
new file mode 100644
index 0000000000..b843764786
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) 2022-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "tests/AssetsLibrary.h"
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/SmallConvolutionLayerDataset.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h"
+#include "tests/validation/reference/ConvolutionLayer.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+/** Tolerances from tests/validation/CL/DirectConvolutionLayer.cpp
+ */
+RelativeTolerance<float> tolerance_f32(
+ 0.05f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+RelativeTolerance<half_float::half> tolerance_f16(half_float::half(
+ 0.2)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr float           abs_tolerance_f32(0.0001f); /**< Absolute tolerance for FP32 tests */
+constexpr float           tolerance_num = 0.07f;      /**< Maximum allowed ratio of mismatching elements */
+} // namespace
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+/** Synced with tests/validation/CL/ConvolutionLayer.cpp
+ *
+ * Difference | Why the difference
+ * f32 tolerance here is smaller | To use the same tolerance as that of DirectConv2d; lowering tolerance is safe
+ * No quantized tests | Not supported yet
+ * No grouped CNN tests | Not supported yet
+ * No mixed layout tests | Not needed; only NHWC is supported
+ * No activation | Not needed in fusion
+ * No ValidateConvolutionMethod | Only a single method (direct conv2d) is supported
+ * No ReshapeWeights = true tests | Not applicable yet. This parameter only concerns gemm-based conv2d
+ * No RunSmallWithPadding tests | Padding is removed
+ *
+ */
+TEST_SUITE(CONV2D)
+
+template <typename T>
+using DynamicFusionGpuConv2dFixture = DynamicFusionGpuConv2dValidationFixture<CLTensor, CLAccessor, GpuConv2d, T>;
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuConv2dFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})),
+ framework::dataset::make("QuantizationInfo", QuantizationInfo())))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuConv2dFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", {DataLayout::NHWC})),
+ framework::dataset::make("QuantizationInfo", QuantizationInfo())))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+TEST_SUITE_END() // FP16
+
+// Tests for specific conv2d methods
+/** Synced with tests/validation/CL/DirectConvolutionLayer.cpp
+ *
+ * Difference | Why the difference
+ * No quantized tests | Not supported yet
+ * No Invalid output size test | Not applicable. Output is removed from the interface
+ * No mixed layout/NCHW tests | Not needed; only NHWC is supported
+ * No activation tests | Not needed in fusion
+ */
+TEST_SUITE(DIRECT_CONV2D)
+
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Invalid: Mismatching data type input/weights
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Invalid: Mismatching input feature maps
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Invalid weights dimensions
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Unsupported biases size
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Unsupported biases dimensions
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, DataLayout::NCHW), // Unsupported data layout: NCHW
+ TensorInfo(TensorShape(2U, 32U, 16U), 1, DataType::QASYMM8, DataLayout::NHWC), // Unsupported data type: quantized
+ TensorInfo(TensorShape(2U, 32U, 16U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Arbitrary weight sizes for NHWC are supported
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Non-rectangular weights dimensions for NHWC are supported
+ TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::F32, DataLayout::NHWC), // Strides > 2 for any kernel sizes for NHWC are supported
+ }),
+ framework::dataset::make("WeightsInfo",{ TensorInfo(TensorShape(2U, 3U, 3U, 4U), 1, DataType::F16, DataLayout::NHWC),
+ TensorInfo(TensorShape(3U, 3U, 3U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 3U, 3U, 4U, 3U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 3U, 3U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 3U, 3U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(2U, 1U, 1U, 4U), 1, DataType::QASYMM8, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 1U, 1U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 13U, 13U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 5U, 3U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(2U, 3U, 3U, 4U), 1, DataType::F32, DataLayout::NHWC),
+ })),
+ framework::dataset::make("BiasesInfo",{ TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(3U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(4U, 2U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(25U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(4U), 1, DataType::QASYMM8, DataLayout::NHWC),
+ TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
+ })),
+ framework::dataset::make("Conv2dAttributes", {
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({1, 1}).pad({0, 0, 0, 0}),
+ Conv2dAttributes().stride({3, 3}).pad({0, 0, 0, 0}),
+ })),
+ framework::dataset::make("Expected", { false, false, false, false, false, false, false, true, true, true, true })),
+ input_info, weights_info, biases_info, conv2d_attrs, expected)
+{
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ const ITensorInfo* sketch_input_info = context.create_tensor_info(input_info);
+ const ITensorInfo* sketch_weights_info = context.create_tensor_info(weights_info);
+ const ITensorInfo* sketch_biases_info = context.create_tensor_info(biases_info);
+ bool is_valid = bool(GpuConv2d::validate_op(sketch, sketch_input_info, sketch_weights_info, sketch_biases_info, conv2d_attrs));
+ ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
+}
+template <typename T>
+using DynamicFusionGpuDirectConv2dFixture = DynamicFusionDirectConv2dValidationFixture<CLTensor, CLAccessor, GpuConv2d, T>;
+
+TEST_SUITE(FP16)
+/// TODO: COMPMID-6877: Once the issue in Conv2d is resolved, re-enable these tests
+FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDirectConv2dFixture<half>, framework::DatasetMode::DISABLED,
+ combine(combine(combine(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputShape", { TensorShape(27U, 13U, 23U),
+ TensorShape(19U, 5U, 16U, 4U),
+ TensorShape(13U, 5U, 17U, 2U),
+ TensorShape(32U, 37U, 13U) } ),
+ framework::dataset::make("StrideX", { 1, 3, 1, 1 })),
+ framework::dataset::make("StrideY", { 1, 3, 2, 1 })),
+ framework::dataset::make("PadX", { 1, 3, 0, 4 })),
+ framework::dataset::make("PadY", { 1, 3, 0, 4 })),
+ framework::dataset::make("KernelSize", { 3, 8, 1, 9 })),
+ framework::dataset::make("NumKernels", { 17, 3, 1, 19 })),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDirectConv2dFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputShape", { TensorShape(800U, 800U, 3U) } ),
+ framework::dataset::make("StrideX", { 1 })),
+ framework::dataset::make("StrideY", { 1 })),
+ framework::dataset::make("PadX", { 1 })),
+ framework::dataset::make("PadY", { 1 })),
+ framework::dataset::make("KernelSize", { 9 })),
+ framework::dataset::make("NumKernels", { 3 })),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(FP32)
+/// TODO: COMPMID-6877: Once the issue in Conv2d is resolved, re-enable these tests
+FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDirectConv2dFixture<float>, framework::DatasetMode::DISABLED,
+ combine(combine(combine(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputShape", { TensorShape(27U, 13U, 23U),
+ TensorShape(19U, 5U, 16U, 4U),
+ TensorShape(13U, 5U, 17U, 2U),
+ TensorShape(32U, 37U, 13U) } ),
+ framework::dataset::make("StrideX", { 1, 3, 1, 1 })),
+ framework::dataset::make("StrideY", { 1, 3, 2, 1 })),
+ framework::dataset::make("PadX", { 1, 3, 0, 4 })),
+ framework::dataset::make("PadY", { 1, 3, 0, 4 })),
+ framework::dataset::make("KernelSize", { 3, 8, 1, 9 })),
+ framework::dataset::make("NumKernels", { 17, 3, 1, 19 })),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.0, abs_tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDirectConv2dFixture<float>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputShape", { TensorShape(800U, 800U, 3U) } ),
+ framework::dataset::make("StrideX", { 1 })),
+ framework::dataset::make("StrideY", { 1 })),
+ framework::dataset::make("PadX", { 1 })),
+ framework::dataset::make("PadY", { 1 })),
+ framework::dataset::make("KernelSize", { 9 })),
+ framework::dataset::make("NumKernels", { 3 })),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.0, abs_tolerance_f32);
+}
+// clang-format on
+// *INDENT-ON*
+
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // DIRECT_CONV2D
+TEST_SUITE_END() // CONV2D
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
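As a cross-check for the StrideX/PadX/KernelSize combinations above: along each spatial dimension, the direct convolution output size follows the usual floor-division relation. A small editorial helper, not part of the patch:

// out = (in + pad_begin + pad_end - kernel) / stride + 1 (integer division).
// e.g. in = 27, kernel = 3, pads = 1 and 1, stride = 1 gives out = 27.
constexpr int conv_out_dim(int in, int kernel, int pad_begin, int pad_end, int stride)
{
    return (in + pad_begin + pad_end - kernel) / stride + 1;
}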
diff --git a/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp b/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp
new file mode 100644
index 0000000000..82d66ca6ce
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp
@@ -0,0 +1,335 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "tests/AssetsLibrary.h"
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/LargeMatMulDataset.h"
+#include "tests/datasets/MatMulDataset.h"
+#include "tests/datasets/SmallMatMulDataset.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h"
+#include "tests/validation/reference/GEMM.h"
+#include "tests/validation/reference/Permute.h"
+#include "tests/validation/Validation.h"
+
+#include <tuple>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+RelativeTolerance<float> tolerance_f32(
+    0.001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+constexpr float abs_tolerance_f32(
+    0.0001f); /**< Absolute tolerance for DataType::F32, used where the relative tolerance fails because of small values */
+RelativeTolerance<half_float::half> tolerance_f16(half_float::half(
+    0.02)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr float abs_tolerance_f16(
+    0.001f); /**< Absolute tolerance for DataType::F16, used where the relative tolerance fails because of small values */
+} // namespace
+
+/** M0 values to test - precommit */
+const auto m0_values_lhs_nt_precommit = framework::dataset::make("M0", {1, 2, 3});
+
+/** N0 values to test - precommit */
+const auto n0_values_rhs_t_precommit = framework::dataset::make("N0", {1, 2, 4});
+
+/** K0 values to test - precommit */
+const auto k0_values_rhs_t_precommit = framework::dataset::make("K0", {1, 2, 4});
+
+/** M0 values to test - nightly */
+const auto m0_values_lhs_nt_nightly = framework::dataset::make("M0", {1, 2, 3, 4});
+
+/** N0 values to test - nightly */
+const auto n0_values_rhs_t_nightly = framework::dataset::make("N0", {1, 2, 3, 4, 8});
+
+/** K0 values to test - nightly */
+const auto k0_values_rhs_t_nightly = framework::dataset::make("K0", {1, 2, 3, 4, 8});
+
+class DFMatMulDataset final : public datasets::MatMulDataset
+{
+public:
+ DFMatMulDataset()
+ {
+ // LHS = [K, M], RHS = [N, K], DST = [N, M]
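+        // e.g. the third config below: LHS (9U, 6U) -> K = 9, M = 6; RHS (5U, 9U) -> N = 5, K = 9;
+        // DST (5U, 6U) -> N = 5, M = 6 (TensorShape lists the innermost dimension first)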
+ add_config(TensorShape(1U, 1U), TensorShape(1U, 1U), TensorShape(1U, 1U));
+ add_config(TensorShape(1U, 2U), TensorShape(2U, 1U), TensorShape(2U, 2U));
+ add_config(TensorShape(9U, 6U), TensorShape(5U, 9U), TensorShape(5U, 6U));
+ add_config(TensorShape(32U, 37U), TensorShape(17U, 32U), TensorShape(17U, 37U));
+ }
+};
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+
+TEST_SUITE(MatMul)
+
+TEST_SUITE(Validate)
+TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL)
+{
+ using MatMulConfigurationPair = std::pair<MatMulKernelInfo, bool>;
+
+ const std::vector<MatMulConfigurationPair> supported_block_sizes = {
+ // MatMulKernelInfo(adj_lhs, adj_rhs, M0, N0, K0, export_rhs_to_cl_image = false)
+
+ // Lhs not-transposed, Rhs transposed
+ {MatMulKernelInfo(false, true, 0, 1, 1), false}, // M0 should be > 0
+ {MatMulKernelInfo(false, true, 3, 11, 1), false}, // N0 not in {1, 2, 3, 4, 8, 16}
+ {MatMulKernelInfo(false, true, 3, 7, 1), false}, // N0 not in {1, 2, 3, 4, 8, 16}
+ {MatMulKernelInfo(false, true, 3, 3, 12), false}, // K0 not in {1, 2, 3, 4, 8, 16}
+ {MatMulKernelInfo(false, true, 3, 3, 6), false}, // K0 not in {1, 2, 3, 4, 8, 16}
+        {MatMulKernelInfo(false, true, 5, 1, 2), true},
+        {MatMulKernelInfo(false, true, 3, 3, 3), true},
+        {MatMulKernelInfo(false, true, 2, 4, 8), true},
+    };
+
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ // Set big enough shapes so that block sizes are not truncated. Also, set all dimensions equal
+ // so that it doesn't fail for different NT/T configurations. We aim to test the block sizes here,
+ // not the shapes themselves.
+ const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32));
+ const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32));
+
+ for (auto &pair : supported_block_sizes)
+ {
+ MatMulAttributes matmul_attr{};
+ matmul_attr.adj_lhs(pair.first.adj_lhs);
+ matmul_attr.adj_rhs(pair.first.adj_rhs);
+
+ GpuMatMulSettings matmul_settings{};
+ matmul_settings.m0(pair.first.m0);
+ matmul_settings.n0(pair.first.n0);
+ matmul_settings.k0(pair.first.k0);
+
+ Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings);
+ ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS);
+ }
+}
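+
+// Editorial illustration, not a library helper: for the Lhs-NT / Rhs-T kernel,
+// the block-size rules exercised above reduce to the following predicate
+// (M0 > 0; N0 and K0 in {1, 2, 3, 4, 8, 16}).
+inline bool is_supported_block_size_nt_t(int m0, int n0, int k0)
+{
+    const auto ok = [](int v) { return v == 1 || v == 2 || v == 3 || v == 4 || v == 8 || v == 16; };
+    return (m0 > 0) && ok(n0) && ok(k0);
+}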
+
+TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
+{
+ // Create a sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+    // Shapes are given in the Nt/Nt convention and are permuted inside the test to exercise the supported Nt/T configuration (only adj_lhs = false, adj_rhs = true below)
+ using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, bool>;
+ const std::vector<ShapeConfigurationTuple> shape_configurations = {
+ {TensorShape(5U, 1U), TensorShape(3U, 5U), true},
+ {TensorShape(10U, 12U), TensorShape(3U, 10U), true},
+ {TensorShape(8U, 4U), TensorShape(2U, 8U), true},
+ {TensorShape(8U, 4U), TensorShape(2U, 5U), false}, // Mismatch in the K dimension
+ {TensorShape(5U, 0U), TensorShape(2U, 5U), false}, // Invalid dimension
+ {TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), true},
+ {TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false}, // no batch broadcasting
+ {TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U),
+ false}, // mismatch in batch dimension
+ };
+
+ for (auto &tuple : shape_configurations)
+ {
+ const bool expected = std::get<2>(tuple);
+
+ for (bool adj_lhs : {false})
+ {
+ for (bool adj_rhs : {true})
+ {
+ TensorShape lhs_shape = std::get<0>(tuple);
+ TensorShape rhs_shape = std::get<1>(tuple);
+
+ if (adj_lhs)
+ {
+ permute(lhs_shape, PermutationVector(1U, 0U));
+ }
+
+ if (adj_rhs)
+ {
+ permute(rhs_shape, PermutationVector(1U, 0U));
+ }
+
+ const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(lhs_shape, 1, DataType::F32));
+ const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(rhs_shape, 1, DataType::F32));
+
+ MatMulAttributes matmul_attr{};
+ matmul_attr.adj_lhs(adj_lhs);
+ matmul_attr.adj_rhs(adj_rhs);
+
+ GpuMatMulSettings matmul_settings{};
+ matmul_settings.m0(1);
+ matmul_settings.n0(1);
+ matmul_settings.k0(1);
+
+ Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings);
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+ }
+ }
+ }
+}
+
+TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL)
+{
+    // All configurations use the Nt/Nt layout; only the data type combinations are validated here
+ using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, bool>;
+ const std::vector<DataTypeConfigurationTuple> data_type_configurations = {
+ {DataType::F32, DataType::F32, DataType::F32, true},
+ {DataType::F16, DataType::F16, DataType::F16, true},
+ {DataType::F16, DataType::F32, DataType::F32, false}, // no mixed precision
+ {DataType::F64, DataType::F64, DataType::F64, false}, // no double precision
+ {DataType::QASYMM8, DataType::QASYMM8, DataType::QASYMM8, false}, // no quantized types
+ {DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, false}, // no quantized types
+ {DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL,
+ false}, // no quantized types
+ {DataType::QASYMM16, DataType::QASYMM16, DataType::QASYMM16, false}, // no quantized types
+ {DataType::QSYMM16, DataType::QSYMM16, DataType::QSYMM16, false}, // no quantized types
+ {DataType::QSYMM8, DataType::QSYMM8, DataType::QSYMM8, false}, // no quantized types
+ {DataType::S64, DataType::S64, DataType::S64, false}, // no integral types
+ {DataType::S32, DataType::S32, DataType::S32, false}, // no integral types
+ {DataType::S16, DataType::S16, DataType::S16, false}, // no integral types
+ {DataType::S8, DataType::S8, DataType::S8, false}, // no integral types
+ {DataType::U64, DataType::U64, DataType::U64, false}, // no integral types
+ {DataType::U32, DataType::U32, DataType::U32, false}, // no integral types
+ {DataType::U16, DataType::U16, DataType::U16, false}, // no integral types
+ {DataType::U8, DataType::U8, DataType::U8, false}, // no integral types
+ };
+ // Create a sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ const TensorShape shape = TensorShape(10U, 10U);
+ MatMulAttributes matmul_attr{};
+ matmul_attr.adj_lhs(false);
+ matmul_attr.adj_rhs(false);
+ GpuMatMulSettings matmul_settings{};
+ matmul_settings.m0(1);
+ matmul_settings.n0(1);
+ matmul_settings.k0(1);
+
+ for (auto &tuple : data_type_configurations)
+ {
+ const bool expected = std::get<3>(tuple);
+
+ const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(shape, 1, std::get<0>(tuple)));
+ const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(shape, 1, std::get<1>(tuple)));
+
+ Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings);
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+ }
+}
+
+TEST_SUITE_END() // Validate
+
+template <typename T>
+using DynamicFusionGpuMatmulFixture = DynamicFusionGpuMatMulValidationFixture<CLTensor, CLAccessor, GpuMatMul, T>;
+
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+
+FIXTURE_DATA_TEST_CASE(RunPrecommit,
+ DynamicFusionGpuMatmulFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(DFMatMulDataset(),
+ framework::dataset::make("TransposeA", {false}),
+ framework::dataset::make("TransposeB", {true}),
+ m0_values_lhs_nt_precommit,
+ n0_values_rhs_t_precommit,
+ k0_values_rhs_t_precommit,
+ framework::dataset::make("ExportRhsToCLImage", {false}),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunNightly,
+ DynamicFusionGpuMatmulFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(DFMatMulDataset(),
+ framework::dataset::make("TransposeA", {false}),
+ framework::dataset::make("TransposeB", {true}),
+ m0_values_lhs_nt_nightly,
+ n0_values_rhs_t_nightly,
+ k0_values_rhs_t_nightly,
+ framework::dataset::make("ExportRhsToCLImage", {false}),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+}
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+
+FIXTURE_DATA_TEST_CASE(RunPrecommit,
+ DynamicFusionGpuMatmulFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(DFMatMulDataset(),
+ framework::dataset::make("TransposeA", {false}),
+ framework::dataset::make("TransposeB", {true}),
+ m0_values_lhs_nt_precommit,
+ n0_values_rhs_t_precommit,
+ k0_values_rhs_t_precommit,
+ framework::dataset::make("ExportRhsToCLImage", {false}),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunNightly,
+ DynamicFusionGpuMatmulFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ combine(DFMatMulDataset(),
+ framework::dataset::make("TransposeA", {false}),
+ framework::dataset::make("TransposeB", {true}),
+ m0_values_lhs_nt_nightly,
+ n0_values_rhs_t_nightly,
+ k0_values_rhs_t_nightly,
+ framework::dataset::make("ExportRhsToCLImage", {false}),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
+}
+
+TEST_SUITE_END() // FP16
+
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // MatMul
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
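For the fixtures above (TransposeA = false, TransposeB = true), the destination shape follows directly from the LHS = [K, M], RHS = [N, K], DST = [N, M] convention noted in DFMatMulDataset. An editorial sketch:

// DST = (N, M): N is the innermost dimension of the RHS (N, K) and M the
// second dimension of the LHS (K, M); TensorShape lists innermost first.
TensorShape matmul_dst_shape(const TensorShape &lhs, const TensorShape &rhs)
{
    return TensorShape(rhs.x(), lhs.y());
}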
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp
new file mode 100644
index 0000000000..af02ce3eaa
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp
@@ -0,0 +1,221 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/DynamicFusionDataset.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+/* Synced with tests/validation/CL/PixelwiseMultiplication.cpp from the standard interface.
+ *
+ * Difference | Why the difference
+ * No integer tests | Not supported yet
+ * No quantized tests | Not supported yet
+ * No convert policy tests | Not needed, as the convert policy is ignored for floating-point types
+ * No scale tests | Not supported yet
+ * No rounding modes tests | Not supported yet
+ * No in place tests | Not supported yet
+ * No activation tests | Not needed in dynamic fusion interface
+ *
+ */
+namespace
+{
+constexpr AbsoluteTolerance<float> tolerance_f16(
+ 0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr AbsoluteTolerance<float> tolerance_f32(
+ 0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+} // namespace
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(MUL)
+
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
+ framework::dataset::make("LhsInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Invalid data type combination
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Unsupported data type U8
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8), // Unsupported data type S8
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), // Unsupported data type S16
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32), // Unsupported data type S32
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), // Unsupported data type QASYMM8
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8_SIGNED), // Unsupported data type QASYMM8_SIGNED
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
+ TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for lhs
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(15U, 23U, 3U), 1, DataType::F32), // Broadcast Y dimension is not allowed
+ TensorInfo(TensorShape( 3U, 8U, 9U), 1, DataType::F32), // Broadcast Z dimension is not allowed
+ TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32), // Batching is allowed
+ }),
+ framework::dataset::make("RhsInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8_SIGNED),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for rhs
+ TensorInfo(TensorShape(15U, 1U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape( 3U, 8U, 1U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32),
+ })),
+ framework::dataset::make("Expected", { true, true, false, false, false, false, false, false, false, false, true, true, false, false, true })),
+ input1_info, input2_info, expected)
+{
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ // Validate Elementwise Mul
+ auto lhs_info = context.create_tensor_info(input1_info);
+ auto rhs_info = context.create_tensor_info(input2_info);
+
+ bool res = bool(GpuMul::validate_op(sketch, lhs_info, rhs_info));
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using DynamicFusionCLMulFixture = DynamicFusionMulOneOpValidationFixture<CLTensor, CLAccessor, GpuMul, T>;
+template <typename T>
+using DynamicFusionCLMulBroadcastFixture = DynamicFusionMulBroadcastValidationFixture<CLTensor, CLAccessor, GpuMul, T>;
+template <typename T>
+using DynamicFusionCLMulTwoOpsFixture = DynamicFusionMulTwoOpsValidationFixture<CLTensor, CLAccessor, GpuMul, T>;
+
+TEST_SUITE(F16)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionCLMulFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", {DataType::F16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
+ DynamicFusionCLMulBroadcastFixture<half>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(datasets::TemporaryLimitedSmallShapesBroadcast(),
+ framework::dataset::make("DataType", {DataType::F16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp,
+ DynamicFusionCLMulBroadcastFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(datasets::TemporaryLimitedLargeShapesBroadcast(),
+ framework::dataset::make("DataType", {DataType::F16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+TEST_SUITE_END() // F16
+
+TEST_SUITE(F32)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionCLMulFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeOneOp,
+ DynamicFusionCLMulFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
+ DynamicFusionCLMulBroadcastFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(datasets::TemporaryLimitedSmallShapesBroadcast(),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp,
+ DynamicFusionCLMulBroadcastFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(datasets::TemporaryLimitedLargeShapesBroadcast(),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
+ DynamicFusionCLMulTwoOpsFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes(),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})),
+ framework::dataset::make("FuseTwoOps", {true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+TEST_SUITE_END() // F32
+
+TEST_SUITE_END() // MUL
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
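The TwoOps fixture exercised above records two multiplications into a single sketch. Conceptually, the recording chains the intermediate tensor info as below; this is an editorial illustration assuming GpuMul::create_op() returns the intermediate ITensorInfo*, in line with the other dynamic fusion operators.

// Editorial sketch of two fused multiplications: dst = (lhs * rhs0) * rhs1
auto               cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
GpuWorkloadContext context{ &cl_compile_ctx };
GpuWorkloadSketch  sketch{ &context };

ITensorInfo *lhs_info  = context.create_tensor_info(TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32));
ITensorInfo *rhs0_info = context.create_tensor_info(TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32));
ITensorInfo *rhs1_info = context.create_tensor_info(TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32));

ITensorInfo *ans0_info = GpuMul::create_op(sketch, lhs_info, rhs0_info);  // first mul
ITensorInfo *ans1_info = GpuMul::create_op(sketch, ans0_info, rhs1_info); // second mul, fused
GpuOutput::create_op(sketch, ans1_info, context.create_tensor_info());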
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
new file mode 100644
index 0000000000..be816b32b3
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
@@ -0,0 +1,219 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuPool2d.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/dynamic_fusion/PoolingLayerDataset.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(POOL2D)
+
+constexpr AbsoluteTolerance<float> tolerance_f32(
+ 0.001f); /**< Tolerance value for comparing reference's output against implementation's output for 32-bit floating-point type */
+constexpr AbsoluteTolerance<float> tolerance_f16(
+ 0.01f); /**< Tolerance value for comparing reference's output against implementation's output for 16-bit floating-point type */
+
+const auto PoolingLayerDatasetFP =
+ combine(combine(combine(combine(framework::dataset::make("PoolingType", {PoolingType::MAX, PoolingType::AVG}),
+ framework::dataset::make("PoolingSize", {Size2D(2, 2), Size2D(3, 3)})),
+ framework::dataset::make("Pad", {Padding2D()})),
+ framework::dataset::make("Stride", {Size2D(1, 1), Size2D(2, 1), Size2D(5, 7)})),
+ framework::dataset::make("ExcludePadding", {true}));
+
+template <typename T>
+using DynamicFusionGpuPool2dFixture = DynamicFusionGpuPool2dValidationFixture<CLTensor, CLAccessor, GpuPool2d, T>;
+
+template <typename T>
+using DFSpecialGpuPool2dFixture = DynamicFusionGpuPool2dSpecialValidationFixture<CLTensor, CLAccessor, GpuPool2d, T>;
+// *INDENT-OFF*
+// clang-format off
+
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(2U, 27U, 13U), 1, DataType::QASYMM8, DataLayout::NHWC), // Invalid parameters, unsupported pooling
+ TensorInfo(TensorShape(5U, 15U, 13U), 1, DataType::F32, DataLayout::NHWC), // Valid Non-rectangular Global Pooling
+ TensorInfo(TensorShape(5U, 13U, 13U), 1, DataType::QASYMM8, DataLayout::NHWC), // Invalid - Quantized not supported.
+ TensorInfo(TensorShape(5U, 13U, 13U), 1, DataType::F32, DataLayout::NHWC), // Valid global pooling
+ TensorInfo(TensorShape(13U, 13U, 5U), 1, DataType::F32, DataLayout::NCHW), // Unsupported data layout
+ }),
+ framework::dataset::make("Pool2dAttributes", {
+ Pool2dAttributes().pool_type(PoolingType::L2).pool_size(Size2D(3,3)).pad(Padding2D(0,0,0,0)).stride(Size2D(1,1)),
+ Pool2dAttributes().pool_type(PoolingType::AVG).pool_size(Size2D(15U, 13U)),
+ Pool2dAttributes().pool_type(PoolingType::AVG).pool_size(Size2D(2,2)).pad(Padding2D()).stride(Size2D(1,1)),
+ Pool2dAttributes().pool_type(PoolingType::AVG).pool_size(Size2D(13U,13U)),
+ Pool2dAttributes().pool_type(PoolingType::AVG).pool_size(Size2D(13U,13U)),
+ })),
+ framework::dataset::make("Expected", { false, true, false, true, false })),
+ input_info, pool2d_attr, expected)
+{
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+    // Declare GpuPool2d settings
+    const GpuPool2dSettings settings{};
+
+ // Validate Pool2d Configuration
+ auto src_info = context.create_tensor_info(input_info);
+ bool res = bool(GpuPool2d::validate_op(sketch, src_info, pool2d_attr, settings));
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+
+// clang-format on
+// *INDENT-ON*
+
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuPool2dFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionGpuPool2dFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunSpecial,
+ DFSpecialGpuPool2dFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(datasets::PoolingLayerDatasetSpecialDynamicFusion(),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+TEST_SUITE(GlobalPooling)
+FIXTURE_DATA_TEST_CASE(
+ RunSmall,
+ DynamicFusionGpuPool2dFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape",
+ {TensorShape(27U, 13U, 2U),
+ TensorShape(27U, 13U, 2U, 4U)}),
+ framework::dataset::make("PoolingType",
+ {PoolingType::AVG, PoolingType::MAX})),
+ framework::dataset::make("PoolingSize", {Size2D(27, 13)})),
+ framework::dataset::make("Pad", {Padding2D()})),
+ framework::dataset::make("Stride", {Size2D(1, 1)})),
+ framework::dataset::make("ExcludePadding", true)),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(
+ RunLarge,
+ DynamicFusionGpuPool2dFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape",
+ {TensorShape(79U, 37U, 11U),
+ TensorShape(79U, 37U, 11U, 4U)}),
+ framework::dataset::make("PoolingType",
+ {PoolingType::AVG, PoolingType::MAX})),
+ framework::dataset::make("PoolingSize", {Size2D(79, 37)})),
+ framework::dataset::make("Pad", {Padding2D()})),
+ framework::dataset::make("Stride", {Size2D(1, 1)})),
+ framework::dataset::make("ExcludePadding", true)),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+TEST_SUITE_END() // GlobalPooling
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+TEST_SUITE(GlobalPooling)
+FIXTURE_DATA_TEST_CASE(
+ RunSmall,
+ DynamicFusionGpuPool2dFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape",
+ {TensorShape(27U, 13U, 2U),
+ TensorShape(27U, 13U, 2U, 4U)}),
+ framework::dataset::make("PoolingType",
+ {PoolingType::AVG, PoolingType::MAX})),
+ framework::dataset::make("PoolingSize", {Size2D(27, 13)})),
+ framework::dataset::make("Pad", {Padding2D()})),
+ framework::dataset::make("Stride", {Size2D(1, 1)})),
+ framework::dataset::make("ExcludePadding", true)),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(
+ RunLarge,
+ DynamicFusionGpuPool2dFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape",
+ {TensorShape(79U, 37U, 11U),
+ TensorShape(79U, 37U, 11U, 4U)}),
+ framework::dataset::make("PoolingType",
+ {PoolingType::AVG, PoolingType::MAX})),
+ framework::dataset::make("PoolingSize", {Size2D(79, 37)})),
+ framework::dataset::make("Pad", {Padding2D()})),
+ framework::dataset::make("Stride", {Size2D(1, 1)})),
+ framework::dataset::make("ExcludePadding", true)),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+TEST_SUITE_END() // GlobalPooling
+TEST_SUITE_END() // FP16
+TEST_SUITE_END() // Float
+
+TEST_SUITE_END() // POOL2D
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
new file mode 100644
index 0000000000..d46754ccca
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/ReshapeLayerDataset.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(RESHAPE)
+
+DATA_TEST_CASE(Validate,
+ framework::DatasetMode::DISABLED,
+ zip(zip(framework::dataset::make(
+ "InputInfo",
+ {
+ TensorInfo(TensorShape(9U, 5U, 7U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32),
+                           TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32) /* mismatching dimensions */,
+ }),
+ framework::dataset::make("OutputShape",
+ {
+ TensorShape(9U, 5U, 21U),
+ TensorShape(8U, 24U, 4U),
+ TensorShape(192U, 192U),
+ })),
+ framework::dataset::make("Expected", {true, true, false})),
+ input_info,
+ output_shape,
+ expected)
+{
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ // Create sketch tensors
+ TensorShape input_shape = input_info.tensor_shape();
+ ARM_COMPUTE_UNUSED(input_shape);
+ ITensorInfo *src_info = context.create_tensor_info(input_info);
+
+ ReshapeAttributes attributes;
+ attributes.shape(output_shape);
+ Status status = GpuReshape::validate_op(sketch, src_info, attributes);
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+}
+
+template <typename T>
+using DynamicFusionGpuReshapeLayerFixture =
+ DynamicFusionGpuReshapeLayerValidationFixture<CLTensor, CLAccessor, GpuReshape, T>;
+
+TEST_SUITE(F32)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuReshapeLayerFixture<float>,
+ framework::DatasetMode::DISABLED,
+ combine(datasets::SmallReshapeLayerDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // F32
+
+TEST_SUITE(F16)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuReshapeLayerFixture<half>,
+ framework::DatasetMode::DISABLED,
+ combine(datasets::SmallReshapeLayerDataset(),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // F16
+
+TEST_SUITE(U8)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuReshapeLayerFixture<uint8_t>,
+ framework::DatasetMode::DISABLED,
+ combine(datasets::SmallReshapeLayerDataset(),
+ framework::dataset::make("DataType", DataType::U8)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // U8
+
+TEST_SUITE(S8)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuReshapeLayerFixture<int8_t>,
+ framework::DatasetMode::DISABLED,
+ combine(datasets::SmallReshapeLayerDataset(),
+ framework::dataset::make("DataType", DataType::S8)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // S8
+
+TEST_SUITE(S16)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionGpuReshapeLayerFixture<int16_t>,
+ framework::DatasetMode::DISABLED,
+ combine(datasets::SmallReshapeLayerDataset(),
+ framework::dataset::make("DataType", DataType::S16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // S16
+
+TEST_SUITE_END() // RESHAPE
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp
new file mode 100644
index 0000000000..a6bcf4ae26
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp
@@ -0,0 +1,359 @@
+/*
+ * Copyright (c) 2022-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuResize.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/ScaleValidationDataset.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h"
+#include "tests/validation/Validation.h"
+
+using namespace arm_compute::experimental::dynamic_fusion;
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+using datasets::ScaleAlignCornersSamplingPolicySet;
+using datasets::ScaleInterpolationPolicySet;
+using datasets::ScaleSamplingPolicySet;
+using datasets::ScaleShapesBaseDataSet;
+
+/** We use a vector size of 16 bytes, since the maximum size of
+ * a vector used by @ref CLScaleKernel is currently 16 bytes (float4).
+ */
+constexpr uint32_t vector_byte = 16;
+
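+// For example, float: 16 / sizeof(float) = 4 elements per vector; half: 16 / sizeof(half) = 8 elements per vector.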
+template <typename T>
+constexpr uint32_t num_elements_per_vector()
+{
+ return vector_byte / sizeof(T);
+}
+
+/** Quantization information data set */
+const auto QuantizationInfoSet = framework::dataset::make("QuantizationInfo",
+ {
+ QuantizationInfo(0.5f, -1),
+ });
+
+/** Tolerance */
+constexpr float tolerance_f32_absolute(0.001f);
+
+RelativeTolerance<float> tolerance_f32(0.05);
+constexpr float abs_tolerance_f16(0.1f);
+RelativeTolerance<half> tolerance_f16(half(0.1));
+
+constexpr float tolerance_num_f32(0.01f);
+
+} // namespace
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(RESIZE)
+
+TEST_SUITE(Validate)
+
+const auto default_input_shape = TensorShape{2, 3, 3, 2};
+const auto default_output_shape = TensorShape{4, 6, 3, 2};
+
+constexpr auto default_data_type = DataType::U8;
+constexpr auto default_data_layout = DataLayout::NHWC;
+
+TEST_CASE(NullPtr, framework::DatasetMode::ALL)
+{
+ const TensorInfo input_info = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout};
+ const TensorInfo output_info = TensorInfo{default_output_shape, 1, default_data_type, default_data_layout};
+
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ // nullptr is given as input
+ Status status = GpuResize::validate_op(sketch, nullptr, ResizeAttributes());
+ ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
+}
+
+TEST_CASE(SupportDataType, framework::DatasetMode::ALL)
+{
+ const std::map<DataType, bool> supported_data_types =
+ {
+ { DataType::U8, false },
+ { DataType::S8, false },
+ { DataType::QSYMM8, false },
+ { DataType::QASYMM8, false },
+ { DataType::QASYMM8_SIGNED, false },
+ { DataType::QSYMM8_PER_CHANNEL, false },
+ { DataType::U16, false },
+ { DataType::S16, false },
+ { DataType::QSYMM16, false },
+ { DataType::QASYMM16, false },
+ { DataType::U32, false },
+ { DataType::S32, false },
+ { DataType::U64, false },
+ { DataType::S64, false },
+ { DataType::BFLOAT16, false },
+ { DataType::F16, true },
+ { DataType::F32, true },
+ { DataType::F64, false },
+ { DataType::SIZET, false },
+ };
+
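+    // Per the map above, GpuResize is expected to accept only F16 and F32; every other data type must fail validation.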
+ for (auto &kv : supported_data_types)
+ {
+ const TensorInfo input_info = TensorInfo{default_input_shape, 1, kv.first, default_data_layout};
+
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
+
+ ResizeAttributes attributes;
+        attributes.output_width(default_output_shape[0]); // the exact output shape is not important, as long as it is non-empty
+ attributes.output_height(default_output_shape[1]);
+
+ Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes);
+ ARM_COMPUTE_EXPECT(bool(status) == kv.second, framework::LogLevel::ERRORS);
+ }
+}
+
+TEST_CASE(MismatchingDataType, framework::DatasetMode::ALL)
+{
+ constexpr DataType non_default_data_type = DataType::F32;
+
+ const TensorInfo input_info = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout};
+ const TensorInfo output_info = TensorInfo{default_output_shape, 1, non_default_data_type, default_data_layout};
+
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
+
+ Status status = GpuResize::validate_op(sketch, sketch_input_info, ResizeAttributes());
+ ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
+}
+
+TEST_CASE(AlignedCornerNotSupported, framework::DatasetMode::ALL)
+{
+    // Aligned corners require the sampling policy to be TOP_LEFT.
+ constexpr InterpolationPolicy interpolation_policy = InterpolationPolicy::BILINEAR;
+ constexpr bool align_corners = true;
+ constexpr SamplingPolicy sampling_policy = SamplingPolicy::CENTER;
+
+ const TensorInfo input_info = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout};
+ const TensorInfo output_info = TensorInfo{default_output_shape, 1, default_data_type, default_data_layout};
+
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
+
+ ResizeAttributes attributes{};
+ attributes.interpolation_policy(interpolation_policy).sampling_policy(sampling_policy).align_corners(align_corners);
+
+ Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes);
+ ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
+}
+
+TEST_CASE(UnsupportedInterpolationPolicy, framework::DatasetMode::ALL)
+{
+ const TensorInfo input_info = TensorInfo{TensorShape(28U, 33U, 2U), 1, DataType::F32, default_data_layout};
+ const TensorInfo output_info = TensorInfo{TensorShape(26U, 21U, 2U), 1, DataType::F32, default_data_layout};
+ constexpr auto interpolation_policy = InterpolationPolicy::AREA;
+
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
+
+ ResizeAttributes attributes{};
+ attributes.interpolation_policy(interpolation_policy);
+
+ Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes);
+ ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
+}
+
+TEST_CASE(UnsupportedLayout, framework::DatasetMode::ALL)
+{
+ const TensorInfo input_info = TensorInfo{default_input_shape, 1, default_data_type, DataLayout::NCHW};
+ const TensorInfo output_info = TensorInfo{default_output_shape, 1, default_data_type, DataLayout::NCHW};
+ constexpr auto interpolation_policy = InterpolationPolicy::BILINEAR;
+
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx};
+ GpuWorkloadSketch sketch{&context};
+
+ const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
+
+ ResizeAttributes attributes{};
+ attributes.interpolation_policy(interpolation_policy);
+
+ Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes);
+ ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
+}
+
+TEST_SUITE_END() // Validate
+
+template <typename T>
+using DynamicFusionResizeFixture = DynamicFusionResizeValidationFixture<CLTensor, CLAccessor, GpuResize, T>;
+
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+
+const auto f32_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<float>())),
+ framework::dataset::make("DataType", DataType::F32));
+
+FIXTURE_DATA_TEST_CASE(Run,
+ DynamicFusionResizeFixture<float>,
+ framework::DatasetMode::ALL,
+ ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleSamplingPolicySet))
+{
+    // Create valid region
+ TensorInfo src_info(_shape, 1, _data_type);
+ const ValidRegion valid_region =
+ calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+
+ // Validate output
+ validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
+}
+
+FIXTURE_DATA_TEST_CASE(RunAlignCorners,
+ DynamicFusionResizeFixture<float>,
+ framework::DatasetMode::ALL,
+ ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleAlignCornersSamplingPolicySet))
+{
+    // Create valid region
+ TensorInfo src_info(_shape, 1, _data_type);
+ const ValidRegion valid_region =
+ calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+
+ // Validate output
+ validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
+}
+const auto f32_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<float>())),
+ framework::dataset::make("DataType", DataType::F32));
+FIXTURE_DATA_TEST_CASE(RunNightly,
+ DynamicFusionResizeFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, ScaleSamplingPolicySet))
+{
+    // Create valid region
+ TensorInfo src_info(_shape, 1, _data_type);
+ const ValidRegion valid_region =
+ calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+
+ // Validate output
+ validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
+}
+FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners,
+ DynamicFusionResizeFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, ScaleAlignCornersSamplingPolicySet))
+{
+    // Create valid region
+ TensorInfo src_info(_shape, 1, _data_type);
+ const ValidRegion valid_region =
+ calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+
+ // Validate output
+ validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
+}
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+const auto f16_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<half>())),
+ framework::dataset::make("DataType", DataType::F16));
+FIXTURE_DATA_TEST_CASE(Run,
+ DynamicFusionResizeFixture<half>,
+ framework::DatasetMode::ALL,
+ ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleSamplingPolicySet))
+{
+    // Create valid region
+ TensorInfo src_info(_shape, 1, _data_type);
+ const ValidRegion valid_region =
+ calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+
+ // Validate output
+ validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16);
+}
+FIXTURE_DATA_TEST_CASE(RunAlignCorners,
+ DynamicFusionResizeFixture<half>,
+ framework::DatasetMode::ALL,
+ ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleAlignCornersSamplingPolicySet))
+{
+    // Create valid region
+ TensorInfo src_info(_shape, 1, _data_type);
+ const ValidRegion valid_region =
+ calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+
+ // Validate output
+ validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16);
+}
+const auto f16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<half>())),
+ framework::dataset::make("DataType", DataType::F16));
+FIXTURE_DATA_TEST_CASE(RunNightly,
+ DynamicFusionResizeFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, ScaleSamplingPolicySet))
+{
+    // Create valid region
+ TensorInfo src_info(_shape, 1, _data_type);
+ const ValidRegion valid_region =
+ calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+
+ // Validate output
+ validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16);
+}
+FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners,
+ DynamicFusionResizeFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, ScaleAlignCornersSamplingPolicySet))
+{
+    // Create valid region
+ TensorInfo src_info(_shape, 1, _data_type);
+ const ValidRegion valid_region =
+ calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+
+ // Validate output
+ validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16);
+}
+TEST_SUITE_END() // FP16
+TEST_SUITE_END() // Float
+
+TEST_SUITE_END() // RESIZE
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp
new file mode 100644
index 0000000000..0134a7c11b
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/Types.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSigmoid.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+constexpr AbsoluteTolerance<float> tolerance_f32(1e-6f);
+constexpr AbsoluteTolerance<float> tolerance_f16(0.001f);
+} // namespace
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(SIGMOID)
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QASYMM8), // Unsupported data type
+ }),
+ framework::dataset::make("Expected", { true, true, false })),
+ input_info, expected)
+{
+ // Create a new workload sketch
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ // Fuse sigmoid
+ const ITensorInfo *src_info = context.create_tensor_info(input_info);
+
+ const bool res = static_cast<bool>(GpuSigmoid::validate_op(sketch, src_info));
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using DynamicFusionSigmoidOpFixture = DynamicFusionSigmoidValidationFixture<CLTensor, CLAccessor, GpuSigmoid, T>;
+
+TEST_SUITE(Float)
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionSigmoidOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
+ DynamicFusionSigmoidOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ ARM_COMPUTE_TEST_INFO("Currently 5D+ tensors are unsupported for this operation.");
+ framework::ARM_COMPUTE_PRINT_INFO();
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
+ DynamicFusionSigmoidOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionSigmoidOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
+ DynamicFusionSigmoidOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ ARM_COMPUTE_TEST_INFO("Currently 5D+ tensors are unsupported for this operation.");
+ framework::ARM_COMPUTE_PRINT_INFO();
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
+ DynamicFusionSigmoidOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // Float
+
+TEST_SUITE_END() // SIGMOID
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
new file mode 100644
index 0000000000..8f5a1ed14a
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
@@ -0,0 +1,219 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/Types.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSoftmax.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h"
+#include "tests/validation/Validation.h"
+
+using namespace arm_compute::experimental::dynamic_fusion;
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+/** Tolerance for float operations */
+RelativeTolerance<half> tolerance_f16(half(0.2));
+RelativeTolerance<float> tolerance_f32(0.001f);
+
+using framework::dataset::make;
+
+/// TODO: COMPMID-6713
+/// Softmax is not yet implemented in CKW, so these tests are DISABLED.
+/// Re-enable them once Softmax is implemented in CKW.
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(SOFTMAX)
+
+// *INDENT-OFF*
+// clang-format off
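+// Note: only softmax along axis 0 is supported; the negative and non-zero axis values below are expected to fail validation.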
+DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED,
+ zip(
+ make("InputInfo", {
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32), // Mismatching data types
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32), // Mismatching shapes
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::S32), // Unsupported data type
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F16),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ }),
+ make("OutputInfo",{
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F16),
+ TensorInfo(TensorShape(27U, 11U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::QASYMM16), // Unsupported data type
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ }),
+ make("beta", {
+ 1.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ }),
+ make("axis", {
+ 0,
+ 0,
+ 1, // Invalid as axis != 0
+ 0,
+ 0,
+ 0,
+ -3, // Invalid as axis != 0
+ 2, // Invalid as axis != 0
+ 1, // Invalid as axis != 0
+ -1, // Invalid as axis != 0
+ }),
+ make("Expected", { false, false, false, true, false, false, false, false, false, false})),
+ input_info, output_info, beta, axis, expected)
+{
+ // Create a new workload sketch
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ SoftmaxAttributes softmax_attr{};
+ softmax_attr.axis(axis).beta(beta).is_log_softmax(false);
+ ITensorInfo* src_info = context.create_tensor_info(input_info);
+ ITensorInfo* dst_info = context.create_tensor_info(output_info);
+ const bool res = static_cast<bool>(GpuSoftmax::validate_op(sketch, src_info, dst_info, softmax_attr));
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+
+template <typename T>
+using DynamicFusionSoftmaxLayerFixture = DynamicFusionSoftmaxValidationFixture<CLTensor, CLAccessor, GpuSoftmax, T>;
+
+TEST_SUITE(FLOAT)
+TEST_SUITE(FP32)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayerSmallShapes(),
+ make("DataType", DataType::F32),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+
+FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayerLargeShapes(),
+ make("DataType", DataType::F32),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+
+FIXTURE_DATA_TEST_CASE(Run4D, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayer4DShapes(),
+ make("DataType", DataType::F32),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+TEST_SUITE_END() // FP32
+TEST_SUITE(FP16)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayerSmallShapes(),
+ make("DataType", DataType::F16),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+
+FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayerLargeShapes(),
+ make("DataType", DataType::F16),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+
+FIXTURE_DATA_TEST_CASE(Run4D, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayer4DShapes(),
+ make("DataType", DataType::F16),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+TEST_SUITE_END() // FP16
+TEST_SUITE_END() // FLOAT
+
+TEST_SUITE_END() // SOFTMAX
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp
new file mode 100644
index 0000000000..c7ab1e717c
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp
@@ -0,0 +1,262 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSub.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/DynamicFusionDataset.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+/* Synced with tests/validation/CL/ArithmeticSubtraction.cpp from the standard interface.
+ *
+ * Difference          | Why the difference
+ * No quantized tests  | Not supported yet
+ * No in-place tests   | Not supported yet
+ * No activation tests | Not needed in the dynamic fusion interface
+ */
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(SUB)
+
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
+ framework::dataset::make("LhsInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U32), // Unsupported data type U32
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), // Unsupported data type QASYMM8
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8_SIGNED), // Unsupported data type QASYMM8_SIGNED
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Invalid data type combination
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), // Invalid data type combination
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32), // Invalid data type combination
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
+ TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for lhs
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(15U, 23U, 3U), 1, DataType::F32), // Broadcast Y dimension is not allowed
+ TensorInfo(TensorShape( 3U, 8U, 9U), 1, DataType::S16), // Invalid data type combination
+                                           TensorInfo(TensorShape(32U, 13U, 2U, 2U), 1, DataType::F32), // Batching is allowed
+ }),
+ framework::dataset::make("RhsInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8_SIGNED),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for rhs
+ TensorInfo(TensorShape(15U, 1U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape( 3U, 8U, 1U), 1, DataType::S16),
+                                           TensorInfo(TensorShape(32U, 13U, 2U, 2U), 1, DataType::F32),
+ })),
+ framework::dataset::make("Expected", { true, false, false, false, false, false, false, false, true, true, false, false, true })),
+ input1_info, input2_info, expected)
+{
+ // Create a new workload sketch
+ auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ // Validate Elementwise Sub
+ auto lhs_info = context.create_tensor_info(input1_info);
+ auto rhs_info = context.create_tensor_info(input2_info);
+
+ bool res = bool(GpuSub::validate_op(sketch, lhs_info, rhs_info));
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using DynamicFusionCLSubFixture =
+ DynamicFusionGpuElementwiseBinaryOneOpValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
+
+template <typename T>
+using DynamicFusionCLSubBroadcastFixture =
+ DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
+
+template <typename T>
+using DynamicFusionCLSubTwoOpsFixture =
+ DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
+
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionCLSubFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunLargeOneOp,
+ DynamicFusionCLSubFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::LargeShapes()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
+ DynamicFusionCLSubBroadcastFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::TemporaryLimitedSmallShapesBroadcast()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp,
+ DynamicFusionCLSubBroadcastFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::TemporaryLimitedLargeShapesBroadcast()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(
+ RunSmallTwoOps,
+ DynamicFusionCLSubTwoOpsFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()),
+ framework::dataset::make("DataType", {DataType::F32})),
+ framework::dataset::make("InPlace", {false})),
+ framework::dataset::make("FuseTwoOps", {true})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionCLSubFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::F16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
+ DynamicFusionCLSubBroadcastFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::TemporaryLimitedSmallShapesBroadcast()),
+ framework::dataset::make("DataType", {DataType::F16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(S32)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionCLSubFixture<int32_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::S32})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // S32
+
+TEST_SUITE(S16)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionCLSubFixture<int16_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::S16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ DynamicFusionCLSubFixture<int16_t>,
+ framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::LargeShapes()),
+ framework::dataset::make("DataType", {DataType::S16})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // S16
+
+TEST_SUITE(U8)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ DynamicFusionCLSubFixture<uint8_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+ datasets::SmallShapes()),
+ framework::dataset::make("DataType", {DataType::U8})),
+ framework::dataset::make("InPlace", {false})))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // U8
+
+TEST_SUITE_END() // SUB
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp
new file mode 100644
index 0000000000..2560f3aab1
--- /dev/null
+++ b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/Types.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuTanh.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
+
+#include "tests/CL/CLAccessor.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Macros.h"
+#include "tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+constexpr AbsoluteTolerance<float> tolerance_f32(0.00001f);
+constexpr AbsoluteTolerance<float> tolerance_f16(0.001f);
+} // namespace
+
+TEST_SUITE(CL)
+TEST_SUITE(DYNAMIC_FUSION)
+TEST_SUITE(TANH)
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QASYMM8), // Unsupported data type
+ }),
+ framework::dataset::make("Expected", { true, true, false })),
+ input_info, expected)
+{
+ // Create a new workload sketch
+ CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+ GpuWorkloadContext context{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
+
+ // Fuse tanh
+ const ITensorInfo* src_info = context.create_tensor_info(input_info);
+
+ const bool res = static_cast<bool>(GpuTanh::validate_op(sketch, src_info));
+ ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using DynamicFusionTanhOpFixture = DynamicFusionTanhValidationFixture<CLTensor, CLAccessor, GpuTanh, T>;
+
+TEST_SUITE(Float)
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionTanhOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
+ DynamicFusionTanhOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ ARM_COMPUTE_TEST_INFO("Currently 5D+ tensors are unsupported for this operation.");
+ framework::ARM_COMPUTE_PRINT_INFO();
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
+ DynamicFusionTanhOpFixture<half>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
+ DynamicFusionTanhOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
+ DynamicFusionTanhOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ ARM_COMPUTE_TEST_INFO("Currently 5D+ tensors are unsupported for this operation.");
+ framework::ARM_COMPUTE_PRINT_INFO();
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
+ DynamicFusionTanhOpFixture<float>,
+ framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // Float
+
+TEST_SUITE_END() // TANH
+TEST_SUITE_END() // DYNAMIC_FUSION
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute