From fdf56fb9d414a754e7cedfdc1351ab0ce2866a0c Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 18 Jan 2024 16:10:46 +0000 Subject: Make GpuWorkloadContext own all tensor info objects * The tensor info objects created by calling create_tensor_info is now solely owned by the context object. The user only receives pointers to those objects. - Internally pointers to tensor info objects are used in various places. It's safer for dynamic fusion to manage these objects directly rather than relying on the users. - The validation test is updated to use the modified API. * Make various changes in dynamic fusion API to make it more friendly (e.g. making some of the objects moveable). Partially resolves: COMPMID-6707 Signed-off-by: Viet-Hoa Do Change-Id: Ifee70e53c05f8e7b72bf9ef123701ff291c5ee80 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10990 Tested-by: Arm Jenkins Reviewed-by: Jakub Sujak Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- .../validation/dynamic_fusion/gpu/Integration.cpp | 240 +++++++------ tests/validation/dynamic_fusion/gpu/cl/Add.cpp | 104 +++--- tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp | 40 ++- .../dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp | 161 +++++---- .../dynamic_fusion/gpu/cl/DirectConv2d.cpp | 47 ++- tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp | 313 ++++++++-------- tests/validation/dynamic_fusion/gpu/cl/Mul.cpp | 56 ++- tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp | 176 +++++---- tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp | 83 +++-- tests/validation/dynamic_fusion/gpu/cl/Resize.cpp | 398 +++++++++++++-------- tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp | 28 +- tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp | 12 +- tests/validation/dynamic_fusion/gpu/cl/Sub.cpp | 99 ++--- tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp | 28 +- 14 files changed, 1011 insertions(+), 774 deletions(-) (limited to 'tests/validation/dynamic_fusion/gpu') diff --git a/tests/validation/dynamic_fusion/gpu/Integration.cpp b/tests/validation/dynamic_fusion/gpu/Integration.cpp index 89cca5cd66..bb9c008f01 100644 --- a/tests/validation/dynamic_fusion/gpu/Integration.cpp +++ b/tests/validation/dynamic_fusion/gpu/Integration.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -37,11 +37,10 @@ #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuDepthwiseConv2d.h" #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h" #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h" - #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSigmoid.h" + #include "tests/CL/CLAccessor.h" #include "tests/framework/Macros.h" -#include "tests/validation/Validation.h" #include "tests/validation/dynamic_fusion/Utils.h" #include "tests/validation/reference/ActivationLayer.h" #include "tests/validation/reference/ConvolutionLayer.h" @@ -50,6 +49,7 @@ #include "tests/validation/reference/ElementwiseOperations.h" #include "tests/validation/reference/Permute.h" #include "tests/validation/reference/PixelWiseMultiplication.h" +#include "tests/validation/Validation.h" using namespace arm_compute::experimental::dynamic_fusion; using namespace arm_compute::test::validation::utils; @@ -79,18 +79,18 @@ TEST_CASE(Conv2d, framework::DatasetMode::ALL) // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; // Fuse conv2d Conv2dAttributes conv2d_attr{}; - TensorInfo input_info = context.create_tensor_info(t_input_shape, 1, data_type, data_layout); - TensorInfo weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout)); + ITensorInfo *input_info = context.create_tensor_info(t_input_shape, 1, data_type, data_layout); + ITensorInfo *weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout)); - ITensorInfo *conv_out_info = GpuConv2d::create_op(sketch, &input_info, &weight_info, nullptr, conv2d_attr); + ITensorInfo *conv_out_info = GpuConv2d::create_op(sketch, input_info, weight_info, nullptr, conv2d_attr); - TensorInfo dst_info = context.create_tensor_info(); - GpuOutput::create_op(sketch, conv_out_info, &dst_info); + ITensorInfo *dst_info = context.create_tensor_info(); + GpuOutput::create_op(sketch, conv_out_info, dst_info); // Configure runtime ClWorkloadRuntime runtime; @@ -98,7 +98,7 @@ TEST_CASE(Conv2d, framework::DatasetMode::ALL) // (Important) Allocate auxiliary tensor memory if there are any // Instead of using ACL allocated memory, the user can choose to import memory into the tensors - for(auto &data : runtime.get_auxiliary_tensors()) + for (auto &data : runtime.get_auxiliary_tensors()) { CLTensor *tensor = std::get<0>(data); TensorInfo info = std::get<1>(data); @@ -115,9 +115,9 @@ TEST_CASE(Conv2d, framework::DatasetMode::ALL) CLTensor t_dst{}; // Initialize user tensors - t_input.allocator()->init(input_info); - t_weight.allocator()->init(weight_info); - t_dst.allocator()->init(dst_info); + t_input.allocator()->init(*input_info); + t_weight.allocator()->init(*weight_info); + t_dst.allocator()->init(*dst_info); // Allocate and fill user tensors // Instead of using ACL allocator, the user can choose to import memory into the tensors @@ -128,12 +128,12 @@ TEST_CASE(Conv2d, framework::DatasetMode::ALL) fill(CLAccessor(t_weight), 1, library.get()); // Run runtime - runtime.run({ &t_input, &t_weight, &t_dst }); + runtime.run({&t_input, &t_weight, &t_dst}); // Create reference - SimpleTensor ref_t_input{ t_input_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; - SimpleTensor ref_t_weight{ t_weight_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; - SimpleTensor ref_t_bias_placeholder{ t_dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + SimpleTensor ref_t_input{t_input_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC}; + SimpleTensor ref_t_weight{t_weight_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC}; + SimpleTensor ref_t_bias_placeholder{t_dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC}; // Fill reference fill(ref_t_input, 0, library.get()); @@ -145,12 +145,15 @@ TEST_CASE(Conv2d, framework::DatasetMode::ALL) auto t_dst_shape_nchw = t_dst_shape; permute(t_dst_shape_nchw, PermutationVector(1U, 2U, 0U)); - PadStrideInfo legacy_pad_stride(conv2d_attr.stride().x(), conv2d_attr.stride().y(), conv2d_attr.pad().left, conv2d_attr.pad().right, conv2d_attr.pad().top, conv2d_attr.pad().bottom, + PadStrideInfo legacy_pad_stride(conv2d_attr.stride().x(), conv2d_attr.stride().y(), conv2d_attr.pad().left, + conv2d_attr.pad().right, conv2d_attr.pad().top, conv2d_attr.pad().bottom, DimensionRoundingType{}); - auto ref_t_dst_nchw = reference::convolution_layer(ref_t_input_nchw, ref_t_weight_nchw, ref_t_bias_placeholder_nchw, t_dst_shape_nchw, legacy_pad_stride, conv2d_attr.dilation()); - const auto ref_t_dst = reference::permute(ref_t_dst_nchw, PermutationVector(2U, 0U, 1U)); + auto ref_t_dst_nchw = reference::convolution_layer(ref_t_input_nchw, ref_t_weight_nchw, ref_t_bias_placeholder_nchw, + t_dst_shape_nchw, legacy_pad_stride, conv2d_attr.dilation()); + const auto ref_t_dst = reference::permute(ref_t_dst_nchw, PermutationVector(2U, 0U, 1U)); - RelativeTolerance tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ + RelativeTolerance tolerance_f32( + 0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ validate(CLAccessor(t_dst), ref_t_dst_nchw, tolerance_f32); } #endif // ACL_INTERNAL_TEST_CKW_IN_DF @@ -167,20 +170,20 @@ TEST_CASE(Add_Output_Add_Output, framework::DatasetMode::ALL) // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; - TensorInfo in_0_info = context.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo in_1_info = context.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo in_2_info = context.create_tensor_info(t_input_shape, 1, data_type); + ITensorInfo *in_0_info = context.create_tensor_info(t_input_shape, 1, data_type); + ITensorInfo *in_1_info = context.create_tensor_info(t_input_shape, 1, data_type); + ITensorInfo *in_2_info = context.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo out_0_info = context.create_tensor_info(); - TensorInfo out_1_info = context.create_tensor_info(); + ITensorInfo *out_0_info = context.create_tensor_info(); + ITensorInfo *out_1_info = context.create_tensor_info(); - ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, &in_0_info, &in_1_info); - GpuOutput::create_op(sketch, ans_0_info, &out_0_info); - ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, &in_2_info); - GpuOutput::create_op(sketch, ans_1_info, &out_1_info); + ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, in_0_info, in_1_info); + GpuOutput::create_op(sketch, ans_0_info, out_0_info); + ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, in_2_info); + GpuOutput::create_op(sketch, ans_1_info, out_1_info); // Configure runtime ClWorkloadRuntime runtime; @@ -188,7 +191,7 @@ TEST_CASE(Add_Output_Add_Output, framework::DatasetMode::ALL) // (Important) Allocate auxiliary tensor memory if there are any // Instead of using ACL allocated memory, the user can choose to import memory into the tensors - for(auto &data : runtime.get_auxiliary_tensors()) + for (auto &data : runtime.get_auxiliary_tensors()) { CLTensor *tensor = std::get<0>(data); TensorInfo info = std::get<1>(data); @@ -208,12 +211,12 @@ TEST_CASE(Add_Output_Add_Output, framework::DatasetMode::ALL) CLTensor t_out_1{}; // Initialize user tensors - t_in_0.allocator()->init(in_0_info); - t_in_1.allocator()->init(in_1_info); - t_in_2.allocator()->init(in_2_info); + t_in_0.allocator()->init(*in_0_info); + t_in_1.allocator()->init(*in_1_info); + t_in_2.allocator()->init(*in_2_info); - t_out_0.allocator()->init(out_0_info); - t_out_1.allocator()->init(out_1_info); + t_out_0.allocator()->init(*out_0_info); + t_out_1.allocator()->init(*out_1_info); // Allocate and fill user tensors // Instead of using ACL allocator, the user can choose to import memory into the tensors @@ -229,15 +232,15 @@ TEST_CASE(Add_Output_Add_Output, framework::DatasetMode::ALL) fill(CLAccessor(t_in_2), 2, library.get()); // Run runtime - runtime.run({ &t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1 }); + runtime.run({&t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1}); // Create reference - SimpleTensor ref_t_in_0{ t_input_shape, data_type, 1, QuantizationInfo() }; - SimpleTensor ref_t_in_1{ t_input_shape, data_type, 1, QuantizationInfo() }; - SimpleTensor ref_t_in_2{ t_input_shape, data_type, 1, QuantizationInfo() }; + SimpleTensor ref_t_in_0{t_input_shape, data_type, 1, QuantizationInfo()}; + SimpleTensor ref_t_in_1{t_input_shape, data_type, 1, QuantizationInfo()}; + SimpleTensor ref_t_in_2{t_input_shape, data_type, 1, QuantizationInfo()}; - SimpleTensor ref_t_out_0{ t_input_shape, data_type, 1, QuantizationInfo() }; - SimpleTensor ref_t_out_1{ t_input_shape, data_type, 1, QuantizationInfo() }; + SimpleTensor ref_t_out_0{t_input_shape, data_type, 1, QuantizationInfo()}; + SimpleTensor ref_t_out_1{t_input_shape, data_type, 1, QuantizationInfo()}; // Fill reference fill(ref_t_in_0, 0, library.get()); @@ -245,9 +248,11 @@ TEST_CASE(Add_Output_Add_Output, framework::DatasetMode::ALL) fill(ref_t_in_2, 2, library.get()); reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_in_0, ref_t_in_1, ref_t_out_0, ConvertPolicy::WRAP); - reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_out_1, ConvertPolicy::WRAP); + reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_out_1, + ConvertPolicy::WRAP); - RelativeTolerance tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ + RelativeTolerance tolerance_f32( + 0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ validate(CLAccessor(t_out_0), ref_t_out_0, tolerance_f32); validate(CLAccessor(t_out_1), ref_t_out_1, tolerance_f32); } @@ -264,15 +269,15 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL) // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; - TensorInfo in_0_info = context.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo in_1_info = context.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo in_2_info = context.create_tensor_info(t_input_shape, 1, data_type); + ITensorInfo *in_0_info = context.create_tensor_info(t_input_shape, 1, data_type); + ITensorInfo *in_1_info = context.create_tensor_info(t_input_shape, 1, data_type); + ITensorInfo *in_2_info = context.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo out_0_info = context.create_tensor_info(); - TensorInfo out_1_info = context.create_tensor_info(); + ITensorInfo *out_0_info = context.create_tensor_info(); + ITensorInfo *out_1_info = context.create_tensor_info(); CastAttributes cast_0_attr; cast_0_attr.data_type(DataType::S32).convert_policy(ConvertPolicy::SATURATE); @@ -280,12 +285,12 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL) CastAttributes cast_1_attr; cast_1_attr.data_type(DataType::F32).convert_policy(ConvertPolicy::SATURATE); - ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, &in_0_info, &in_1_info); - GpuOutput::create_op(sketch, ans_0_info, &out_0_info); - ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, &in_2_info); + ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, in_0_info, in_1_info); + GpuOutput::create_op(sketch, ans_0_info, out_0_info); + ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, in_2_info); ITensorInfo *ans_2_info = GpuCast::create_op(sketch, ans_1_info, cast_0_attr); ITensorInfo *ans_3_info = GpuCast::create_op(sketch, ans_2_info, cast_1_attr); - GpuOutput::create_op(sketch, ans_3_info, &out_1_info); + GpuOutput::create_op(sketch, ans_3_info, out_1_info); // Configure runtime ClWorkloadRuntime runtime; @@ -293,7 +298,7 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL) // (Important) Allocate auxiliary tensor memory if there are any // Instead of using ACL allocated memory, the user can choose to import memory into the tensors - for(auto &data : runtime.get_auxiliary_tensors()) + for (auto &data : runtime.get_auxiliary_tensors()) { CLTensor *tensor = std::get<0>(data); TensorInfo info = std::get<1>(data); @@ -313,12 +318,12 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL) CLTensor t_out_1{}; // Initialize user tensors - t_in_0.allocator()->init(in_0_info); - t_in_1.allocator()->init(in_1_info); - t_in_2.allocator()->init(in_2_info); + t_in_0.allocator()->init(*in_0_info); + t_in_1.allocator()->init(*in_1_info); + t_in_2.allocator()->init(*in_2_info); - t_out_0.allocator()->init(out_0_info); - t_out_1.allocator()->init(out_1_info); + t_out_0.allocator()->init(*out_0_info); + t_out_1.allocator()->init(*out_1_info); // Allocate and fill user tensors // Instead of using ACL allocator, the user can choose to import memory into the tensors @@ -334,15 +339,15 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL) fill(CLAccessor(t_in_2), 2, library.get()); // Run runtime - runtime.run({ &t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1 }); + runtime.run({&t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1}); // Create reference - SimpleTensor ref_t_in_0{ t_input_shape, data_type, 1, QuantizationInfo() }; - SimpleTensor ref_t_in_1{ t_input_shape, data_type, 1, QuantizationInfo() }; - SimpleTensor ref_t_in_2{ t_input_shape, data_type, 1, QuantizationInfo() }; + SimpleTensor ref_t_in_0{t_input_shape, data_type, 1, QuantizationInfo()}; + SimpleTensor ref_t_in_1{t_input_shape, data_type, 1, QuantizationInfo()}; + SimpleTensor ref_t_in_2{t_input_shape, data_type, 1, QuantizationInfo()}; - SimpleTensor ref_t_out_0{ t_input_shape, data_type, 1, QuantizationInfo() }; - SimpleTensor ref_t_ans_1{ t_input_shape, data_type, 1, QuantizationInfo() }; + SimpleTensor ref_t_out_0{t_input_shape, data_type, 1, QuantizationInfo()}; + SimpleTensor ref_t_ans_1{t_input_shape, data_type, 1, QuantizationInfo()}; // Fill reference fill(ref_t_in_0, 0, library.get()); @@ -350,9 +355,12 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL) fill(ref_t_in_2, 2, library.get()); reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_in_0, ref_t_in_1, ref_t_out_0, ConvertPolicy::WRAP); - reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_ans_1, ConvertPolicy::WRAP); - const auto ref_t_ans_2 = reference::depth_convert(ref_t_ans_1, DataType::S32, ConvertPolicy::SATURATE, 0); - const auto ref_t_out_1 = reference::depth_convert(ref_t_ans_2, DataType::F32, ConvertPolicy::SATURATE, 0); + reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_ans_1, + ConvertPolicy::WRAP); + const auto ref_t_ans_2 = + reference::depth_convert(ref_t_ans_1, DataType::S32, ConvertPolicy::SATURATE, 0); + const auto ref_t_out_1 = + reference::depth_convert(ref_t_ans_2, DataType::F32, ConvertPolicy::SATURATE, 0); RelativeTolerance tolerance_add_f32(0.001f); AbsoluteTolerance tolerance_cast_f32(1.0f); @@ -436,20 +444,22 @@ TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL) Conv2dAttributes conv2d_attr; auto tensor1_info = context.create_tensor_info(conv2d_wei_shape, 1, DataType::F32, DataLayout::NHWC); auto tensor2_info = context.create_tensor_info(conv2d_bia_shape, 1, DataType::F32, DataLayout::NHWC); - ARM_COMPUTE_EXPECT(GpuConv2d::validate_op(sketch0, &tensor0_info, &tensor1_info, &tensor2_info, conv2d_attr), framework::LogLevel::ERRORS); - auto ans_info = GpuConv2d::create_op(sketch0, &tensor0_info, &tensor1_info, &tensor2_info, conv2d_attr); + ARM_COMPUTE_EXPECT(GpuConv2d::validate_op(sketch0, tensor0_info, tensor1_info, tensor2_info, conv2d_attr), + framework::LogLevel::ERRORS); + auto ans_info = GpuConv2d::create_op(sketch0, tensor0_info, tensor1_info, tensor2_info, conv2d_attr); ARM_COMPUTE_EXPECT(GpuSigmoid::validate_op(sketch0, ans_info), framework::LogLevel::ERRORS); ans_info = GpuSigmoid::create_op(sketch0, ans_info); DepthwiseConv2dAttributes dwc_attr; - auto tensor3_info = context.create_tensor_info(dwc_wei_shape, 1, DataType::F32, DataLayout::NHWC); - auto tensor4_info = context.create_tensor_info(dwc_bia_shape, 1, DataType::F32, DataLayout::NHWC); - ARM_COMPUTE_EXPECT(!GpuDepthwiseConv2d::validate_op(sketch0, ans_info, &tensor3_info, &tensor4_info, dwc_attr), framework::LogLevel::ERRORS); + auto tensor3_info = context.create_tensor_info(dwc_wei_shape, 1, DataType::F32, DataLayout::NHWC); + auto tensor4_info = context.create_tensor_info(dwc_bia_shape, 1, DataType::F32, DataLayout::NHWC); + ARM_COMPUTE_EXPECT(!GpuDepthwiseConv2d::validate_op(sketch0, ans_info, tensor3_info, tensor4_info, dwc_attr), + framework::LogLevel::ERRORS); auto tensor5_info = context.create_tensor_info(); - ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch0, ans_info, &tensor5_info), framework::LogLevel::ERRORS); - GpuOutput::create_op(sketch0, ans_info, &tensor5_info); + ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch0, ans_info, tensor5_info), framework::LogLevel::ERRORS); + GpuOutput::create_op(sketch0, ans_info, tensor5_info); // Create the first workload runtime. ClWorkloadRuntime runtime0; @@ -458,15 +468,16 @@ TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL) // Create the second sketch: dwc + sigmoid + output. GpuWorkloadSketch sketch1(&context); - ARM_COMPUTE_EXPECT(GpuDepthwiseConv2d::validate_op(sketch1, &tensor5_info, &tensor3_info, &tensor4_info, dwc_attr), framework::LogLevel::ERRORS); - ans_info = GpuDepthwiseConv2d::create_op(sketch1, &tensor5_info, &tensor3_info, &tensor4_info, dwc_attr); + ARM_COMPUTE_EXPECT(GpuDepthwiseConv2d::validate_op(sketch1, tensor5_info, tensor3_info, tensor4_info, dwc_attr), + framework::LogLevel::ERRORS); + ans_info = GpuDepthwiseConv2d::create_op(sketch1, tensor5_info, tensor3_info, tensor4_info, dwc_attr); - ARM_COMPUTE_EXPECT(GpuMul::validate_op(sketch1, ans_info, &tensor2_info), framework::LogLevel::ERRORS); - ans_info = GpuMul::create_op(sketch1, ans_info, &tensor2_info); + ARM_COMPUTE_EXPECT(GpuMul::validate_op(sketch1, ans_info, tensor2_info), framework::LogLevel::ERRORS); + ans_info = GpuMul::create_op(sketch1, ans_info, tensor2_info); auto tensor6_info = context.create_tensor_info(); - ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch1, ans_info, &tensor6_info), framework::LogLevel::ERRORS); - GpuOutput::create_op(sketch1, ans_info, &tensor6_info); + ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch1, ans_info, tensor6_info), framework::LogLevel::ERRORS); + GpuOutput::create_op(sketch1, ans_info, tensor6_info); // Create the second workload runtime. ClWorkloadRuntime runtime1; @@ -481,13 +492,13 @@ TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL) CLTensor tensor5; CLTensor tensor6; - tensor0.allocator()->init(tensor0_info); - tensor1.allocator()->init(tensor1_info); - tensor2.allocator()->init(tensor2_info); - tensor3.allocator()->init(tensor3_info); - tensor4.allocator()->init(tensor4_info); - tensor5.allocator()->init(tensor5_info); - tensor6.allocator()->init(tensor6_info); + tensor0.allocator()->init(*tensor0_info); + tensor1.allocator()->init(*tensor1_info); + tensor2.allocator()->init(*tensor2_info); + tensor3.allocator()->init(*tensor3_info); + tensor4.allocator()->init(*tensor4_info); + tensor5.allocator()->init(*tensor5_info); + tensor6.allocator()->init(*tensor6_info); tensor0.allocator()->allocate(); tensor1.allocator()->allocate(); @@ -498,7 +509,7 @@ TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL) tensor6.allocator()->allocate(); // Allocate the auxiliary tensors. - for(auto &data : runtime0.get_auxiliary_tensors()) + for (auto &data : runtime0.get_auxiliary_tensors()) { auto tensor = std::get<0>(data); auto &tensor_info = std::get<1>(data); @@ -508,7 +519,7 @@ TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL) tensor->allocator()->allocate(); } - for(auto &data : runtime1.get_auxiliary_tensors()) + for (auto &data : runtime1.get_auxiliary_tensors()) { auto tensor = std::get<0>(data); auto &tensor_info = std::get<1>(data); @@ -526,8 +537,8 @@ TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL) fill(CLAccessor(tensor4), 4, library.get()); // Run each runtime. - runtime0.run({ &tensor0, &tensor1, &tensor2, &tensor5 }); - runtime1.run({ &tensor5, &tensor3, &tensor4, &tensor2, &tensor6 }); + runtime0.run({&tensor0, &tensor1, &tensor2, &tensor5}); + runtime1.run({&tensor5, &tensor3, &tensor4, &tensor2, &tensor6}); // Compute the reference result. SimpleTensor ref_conv2d_src(conv2d_src_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC); @@ -549,18 +560,22 @@ TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL) const auto ref_conv2d_src_nchw = reference::permute(ref_conv2d_src, nhwc_to_nchw); const auto ref_conv2d_wei_nchw = reference::permute(ref_conv2d_wei, nhwc_to_nchw); const auto ref_conv2d_bia_nchw = reference::permute(ref_conv2d_bia, nhwc_to_nchw); - const auto ref_conv2d_dst_nchw = reference::convolution_layer(ref_conv2d_src_nchw, ref_conv2d_wei_nchw, ref_conv2d_bia_nchw, conv2d_dst_shape_nchw, PadStrideInfo()); + const auto ref_conv2d_dst_nchw = reference::convolution_layer( + ref_conv2d_src_nchw, ref_conv2d_wei_nchw, ref_conv2d_bia_nchw, conv2d_dst_shape_nchw, PadStrideInfo()); - const auto ref_sigmoid_dst_nchw = reference::activation_layer(ref_conv2d_dst_nchw, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); + const auto ref_sigmoid_dst_nchw = reference::activation_layer( + ref_conv2d_dst_nchw, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); auto dwc_dst_shape_nchw = dwc_dst_shape; permute(dwc_dst_shape_nchw, nhwc_to_nchw); const auto ref_dwc_wei_nchw = reference::permute(ref_dwc_wei, nhwc_to_nchw); const auto ref_dwc_bia_nchw = reference::permute(ref_dwc_bia, nhwc_to_nchw); - const auto ref_dwc_dst_nchw = reference::depthwise_convolution(ref_sigmoid_dst_nchw, ref_dwc_wei_nchw, ref_dwc_bia_nchw, dwc_dst_shape_nchw, PadStrideInfo(), 1); + const auto ref_dwc_dst_nchw = reference::depthwise_convolution( + ref_sigmoid_dst_nchw, ref_dwc_wei_nchw, ref_dwc_bia_nchw, dwc_dst_shape_nchw, PadStrideInfo(), 1); - const auto ref_mul_dst_nchw = reference::pixel_wise_multiplication(ref_dwc_dst_nchw, ref_conv2d_bia_nchw, 1.0, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP, - DataType::F32); + const auto ref_mul_dst_nchw = reference::pixel_wise_multiplication( + ref_dwc_dst_nchw, ref_conv2d_bia_nchw, 1.0, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP, + DataType::F32); constexpr RelativeTolerance tolerance(0.001f); validate(CLAccessor(tensor6), ref_mul_dst_nchw, tolerance); @@ -587,34 +602,35 @@ TEST_CASE(Multiple_Complex_Ops_0, framework::DatasetMode::ALL) // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; // Create tensor infos - TensorInfo input_info = context.create_tensor_info(t_input_shape, 1, data_type, data_layout); - TensorInfo weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout)); + ITensorInfo *input_info = context.create_tensor_info(t_input_shape, 1, data_type, data_layout); + ITensorInfo *weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout)); ITensorInfo *dst_info; // Fuse conv2d into the workload { // Validate operator - const Status success = GpuConv2d::validate_op(sketch, &input_info, &weight_info, nullptr, conv2d_attr); + const Status success = GpuConv2d::validate_op(sketch, input_info, weight_info, nullptr, conv2d_attr); ARM_COMPUTE_EXPECT(bool(success), framework::LogLevel::ERRORS); - dst_info = GpuConv2d::create_op(sketch, &input_info, &weight_info, nullptr, conv2d_attr); + dst_info = GpuConv2d::create_op(sketch, input_info, weight_info, nullptr, conv2d_attr); } // Create tensor infos - TensorInfo weight_info_2 = context.create_tensor_info(t_weight_info); + ITensorInfo *weight_info_2 = context.create_tensor_info(t_weight_info); // Fuse conv2d into the workload { // Validate operator, should fail - const Status success = GpuConv2d::validate_op(sketch, dst_info, &weight_info_2, nullptr, conv2d_attr); - const auto expected_error_str = "Operator fusion test failed. This operator cannot be fused into the workload"; + const Status success = GpuConv2d::validate_op(sketch, dst_info, weight_info_2, nullptr, conv2d_attr); + const auto expected_error_str = "Operator fusion test failed. This operator cannot be fused into the workload"; ARM_COMPUTE_EXPECT(!bool(success), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT((success.error_description().find(expected_error_str) != std::string::npos), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((success.error_description().find(expected_error_str) != std::string::npos), + framework::LogLevel::ERRORS); } } TEST_SUITE_END() // Invalid_Fusion_Should_Fail diff --git a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp index 09a8f3fe39..a358d47bdd 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,14 +29,13 @@ #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuAdd.h" #include "tests/CL/CLAccessor.h" -#include "tests/framework/Fixture.h" -#include "tests/framework/Macros.h" -#include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" - #include "tests/datasets/DynamicFusionDataset.h" #include "tests/datasets/ShapeDatasets.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/framework/Fixture.h" +#include "tests/framework/Macros.h" #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -97,32 +96,36 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( auto lhs_info = context.create_tensor_info(input1_info); auto rhs_info = context.create_tensor_info(input2_info); - bool res = bool(GpuAdd::validate_op(sketch, &lhs_info, &rhs_info)); + bool res = bool(GpuAdd::validate_op(sketch, lhs_info, rhs_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } // clang-format on // *INDENT-ON* -constexpr AbsoluteTolerance tolerance_f(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 and DataType::F16 */ -constexpr float tolerance_num = 0.0001f; /**< Tolerance number */ +constexpr AbsoluteTolerance tolerance_f( + 0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 and DataType::F16 */ +constexpr float tolerance_num = 0.0001f; /**< Tolerance number */ template -using DynamicFusionCLAddFixture = DynamicFusionGpuElementwiseBinaryOneOpValidationFixture; +using DynamicFusionCLAddFixture = + DynamicFusionGpuElementwiseBinaryOneOpValidationFixture; template -using DynamicFusionCLAddBroadcastFixture = DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture; +using DynamicFusionCLAddBroadcastFixture = + DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture; template -using DynamicFusionCLAddTwoOpsFixture = DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture; +using DynamicFusionCLAddTwoOpsFixture = + DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture; TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionCLAddFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f); @@ -130,10 +133,10 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunLargeOneOp, DynamicFusionCLAddFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::LargeShapes()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f); @@ -141,10 +144,10 @@ FIXTURE_DATA_TEST_CASE(RunLargeOneOp, FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, DynamicFusionCLAddBroadcastFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::TemporaryLimitedSmallShapesBroadcast()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f); @@ -153,22 +156,23 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp, DynamicFusionCLAddBroadcastFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::TemporaryLimitedLargeShapesBroadcast()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f); } -FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, - DynamicFusionCLAddTwoOpsFixture, - framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), - datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false })), - framework::dataset::make("FuseTwoOps", { true }))) +FIXTURE_DATA_TEST_CASE( + RunSmallTwoOps, + DynamicFusionCLAddTwoOpsFixture, + framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), + datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()), + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false})), + framework::dataset::make("FuseTwoOps", {true}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f); @@ -179,10 +183,10 @@ TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionCLAddFixture, framework::DatasetMode::ALL, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::F16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f, tolerance_num); @@ -191,10 +195,10 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, DynamicFusionCLAddBroadcastFixture, framework::DatasetMode::ALL, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::TemporaryLimitedSmallShapesBroadcast()), - framework::dataset::make("DataType", { DataType::F16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f, tolerance_num); @@ -206,10 +210,10 @@ TEST_SUITE(S32) FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionCLAddFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::S32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::S32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -220,10 +224,10 @@ TEST_SUITE(S16) FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionCLAddFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::S16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::S16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -231,10 +235,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionCLAddFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::LargeShapes()), - framework::dataset::make("DataType", { DataType::S16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::S16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -245,10 +249,10 @@ TEST_SUITE(U8) FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionCLAddFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::U8 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::U8})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp index 285c0d6608..cef8b87c3f 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,10 +29,10 @@ #include "tests/CL/CLAccessor.h" #include "tests/datasets/ShapeDatasets.h" #include "tests/framework/Asserts.h" -#include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" +#include "tests/framework/Macros.h" #include "tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -73,13 +73,13 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( GpuWorkloadSketch sketch{ &context }; // Fuse Clamp - const TensorInfo src_info = context.create_tensor_info(input_info); + const ITensorInfo* src_info = context.create_tensor_info(input_info); ClampAttributes attributes {}; attributes.min_val(min_val) .max_val(max_val); - const bool res = static_cast(GpuClamp::validate_op(sketch, &src_info, attributes)); + const bool res = static_cast(GpuClamp::validate_op(sketch, src_info, attributes)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } // clang-format on @@ -94,8 +94,9 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionClampOpFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.1f).max_val(0.6f) })), - framework::dataset::make("Fuse", { false })), + framework::dataset::make( + "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.6f)})), + framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -106,8 +107,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, DynamicFusionClampOpFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::Small5dShapes(), - framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.1f).max_val(0.6f) })), - framework::dataset::make("Fuse", { false })), + framework::dataset::make( + "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.6f)})), + framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -119,8 +121,9 @@ FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, DynamicFusionClampOpFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.2f).max_val(0.4f) })), - framework::dataset::make("Fuse", { true })), + framework::dataset::make( + "ClampAttributes", {ClampAttributes().min_val(0.2f).max_val(0.4f)})), + framework::dataset::make("Fuse", {true})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -134,8 +137,9 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionClampOpFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.3f).max_val(0.7f) })), - framework::dataset::make("Fuse", { false })), + framework::dataset::make( + "ClampAttributes", {ClampAttributes().min_val(0.3f).max_val(0.7f)})), + framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F32))) { // Validate output @@ -146,8 +150,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, DynamicFusionClampOpFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::Small5dShapes(), - framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.3f).max_val(0.7f) })), - framework::dataset::make("Fuse", { false })), + framework::dataset::make( + "ClampAttributes", {ClampAttributes().min_val(0.3f).max_val(0.7f)})), + framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F32))) { // Validate output @@ -159,8 +164,9 @@ FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, DynamicFusionClampOpFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.1f).max_val(0.9f) })), - framework::dataset::make("Fuse", { true })), + framework::dataset::make( + "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.9f)})), + framework::dataset::make("Fuse", {true})), framework::dataset::make("DataType", DataType::F32))) { // Validate output diff --git a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp index aec1306a31..40e1ea8929 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,11 +28,11 @@ #include "tests/datasets/DepthwiseConvolutionLayerDataset.h" #include "tests/datasets/DilatedDepthwiseConvolutionLayerDataset.h" #include "tests/framework/Asserts.h" +#include "tests/framework/datasets/Datasets.h" #include "tests/framework/Fixture.h" #include "tests/framework/Macros.h" -#include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -40,16 +40,18 @@ namespace test { namespace validation { -const auto depth_multipliers = framework::dataset::make("DepthMultiplier", { 1U, 4U }); -const auto large_depth_multipliers = framework::dataset::make("DepthMultiplier", { 1, 2, 5, 8 }); +const auto depth_multipliers = framework::dataset::make("DepthMultiplier", {1U, 4U}); +const auto large_depth_multipliers = framework::dataset::make("DepthMultiplier", {1, 2, 5, 8}); TEST_SUITE(CL) TEST_SUITE(DYNAMIC_FUSION) TEST_SUITE(DEPTHWISE_CONV2D) -RelativeTolerance tolerance_f32(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ -RelativeTolerance tolerance_f16(half_float::half(0.1)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ -constexpr float tolerance_num = 0.02f; /**< Tolerance number */ +RelativeTolerance tolerance_f32( + 0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ +RelativeTolerance tolerance_f16(half_float::half( + 0.1)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ +constexpr float tolerance_num = 0.02f; /**< Tolerance number */ // *INDENT-OFF* // clang-format off @@ -245,9 +247,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = context.create_tensor_info(input_info); - const TensorInfo sketch_weights_info = context.create_tensor_info(weights_info); - const TensorInfo sketch_biases_info = context.create_tensor_info(biases_info); + const ITensorInfo* sketch_input_info = context.create_tensor_info(input_info); + const ITensorInfo* sketch_weights_info = context.create_tensor_info(weights_info); + const ITensorInfo* sketch_biases_info = context.create_tensor_info(biases_info); DepthwiseConv2dAttributes attributes {}; attributes.pad(padding) @@ -255,7 +257,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi .dilation(dilation) .depth_multiplier(depth_multiplier); - const Status status = GpuDepthwiseConv2d::validate_op(sketch, &sketch_input_info, &sketch_weights_info, &sketch_biases_info, attributes); + const Status status = GpuDepthwiseConv2d::validate_op(sketch, sketch_input_info, sketch_weights_info, sketch_biases_info, attributes); const bool res = bool(status); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } @@ -263,40 +265,50 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi // *INDENT-ON* template -using DynamicFusionGpuDepthwiseConv2dFixture = DynamicFusionGpuDepthwiseConv2dValidationFixture; +using DynamicFusionGpuDepthwiseConv2dFixture = + DynamicFusionGpuDepthwiseConv2dValidationFixture; TEST_SUITE(Float) TEST_SUITE(FP16) TEST_SUITE(W3x3) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, - combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), - depth_multipliers), +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataLayout", DataLayout::NHWC))) { validate(CLAccessor(_target), _reference, tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(), - large_depth_multipliers), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", DataLayout::NHWC))) +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::NIGHTLY, + combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(), + large_depth_multipliers), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", DataLayout::NHWC))) { validate(CLAccessor(_target), _reference, tolerance_f16); } #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel TEST_SUITE(Dilation) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(), - depth_multipliers), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(), + depth_multipliers), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::NIGHTLY, +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(), large_depth_multipliers), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f16); } @@ -305,34 +317,44 @@ TEST_SUITE_END() // Dilation TEST_SUITE_END() // W3x3 TEST_SUITE(Generic) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(), - depth_multipliers), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(), depth_multipliers), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(), - large_depth_multipliers), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::NIGHTLY, + combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(), + large_depth_multipliers), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num); } #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel TEST_SUITE(Dilation) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(), - depth_multipliers), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(), + depth_multipliers), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::NIGHTLY, +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(), large_depth_multipliers), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num); } @@ -343,15 +365,18 @@ TEST_SUITE_END() // FP16 TEST_SUITE(FP32) TEST_SUITE(W3x3) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, - combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), - depth_multipliers), +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", DataLayout::NHWC))) { validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::NIGHTLY, +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(), large_depth_multipliers), framework::dataset::make("DataType", DataType::F32)), @@ -363,7 +388,9 @@ FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel TEST_SUITE(Dilation) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(), depth_multipliers), framework::dataset::make("DataType", DataType::F32)), @@ -371,7 +398,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, { validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::NIGHTLY, +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(), large_depth_multipliers), framework::dataset::make("DataType", DataType::F32)), @@ -384,47 +413,57 @@ TEST_SUITE_END() // Dilation TEST_SUITE_END() // W3x3 TEST_SUITE(Generic) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, - combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(), - depth_multipliers), +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(), depth_multipliers), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::NIGHTLY, +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(), large_depth_multipliers), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLargeKernelSize, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, +FIXTURE_DATA_TEST_CASE(RunLargeKernelSize, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, combine(combine(combine(datasets::LargeKernelSizeDepthwiseConvolutionLayerNHWCDataset(), - framework::dataset::make("DepthMultiplier", { 1 })), + framework::dataset::make("DepthMultiplier", {1})), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel TEST_SUITE(Dilation) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(), - depth_multipliers), - framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(), + depth_multipliers), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture, framework::DatasetMode::NIGHTLY, +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuDepthwiseConv2dFixture, + framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(), large_depth_multipliers), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NHWC }))) + framework::dataset::make("DataLayout", {DataLayout::NHWC}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } diff --git a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp index bae8cbf868..dae550003e 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,14 +24,13 @@ #include "tests/AssetsLibrary.h" #include "tests/CL/CLAccessor.h" +#include "tests/datasets/SmallConvolutionLayerDataset.h" +#include "tests/framework/datasets/Datasets.h" #include "tests/framework/Fixture.h" #include "tests/framework/Macros.h" -#include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" -#include "tests/validation/reference/ConvolutionLayer.h" - -#include "tests/datasets/SmallConvolutionLayerDataset.h" #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h" +#include "tests/validation/reference/ConvolutionLayer.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -43,10 +42,12 @@ namespace { /** Tolerances from tests/validation/CL/DirectConvolutionLayer.cpp */ -RelativeTolerance tolerance_f32(0.05f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ -RelativeTolerance tolerance_f16(half_float::half(0.2)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ -constexpr float abs_tolerance_f32(0.0001f); /**< Absolute tolerance for FP32 tests*/ -constexpr float tolerance_num = 0.07f; /**< Tolerance number */ +RelativeTolerance tolerance_f32( + 0.05f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ +RelativeTolerance tolerance_f16(half_float::half( + 0.2)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ +constexpr float abs_tolerance_f32(0.0001f); /**< Absolute tolerance for FP32 tests*/ +constexpr float tolerance_num = 0.07f; /**< Tolerance number */ } // namespace TEST_SUITE(CL) @@ -69,8 +70,13 @@ TEST_SUITE(CONV2D) template using DynamicFusionGpuConv2dFixture = DynamicFusionGpuConv2dValidationFixture; TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuConv2dFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), framework::dataset::make("QuantizationInfo", QuantizationInfo()))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", {DataLayout::NHWC})), + framework::dataset::make("QuantizationInfo", QuantizationInfo()))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); @@ -78,8 +84,13 @@ FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuConv2dFixture, framework TEST_SUITE_END() // FP32 TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuConv2dFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), framework::dataset::make("QuantizationInfo", QuantizationInfo()))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuConv2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", {DataLayout::NHWC})), + framework::dataset::make("QuantizationInfo", QuantizationInfo()))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num); @@ -156,10 +167,10 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( auto context = GpuWorkloadContext{ &cl_compile_ctx }; GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = context.create_tensor_info(input_info); - const TensorInfo sketch_weights_info = context.create_tensor_info(weights_info); - const TensorInfo sketch_biases_info = context.create_tensor_info(biases_info); - bool is_valid = bool(GpuConv2d::validate_op(sketch, &sketch_input_info, &sketch_weights_info, &sketch_biases_info, conv2d_attrs)); + const ITensorInfo* sketch_input_info = context.create_tensor_info(input_info); + const ITensorInfo* sketch_weights_info = context.create_tensor_info(weights_info); + const ITensorInfo* sketch_biases_info = context.create_tensor_info(biases_info); + bool is_valid = bool(GpuConv2d::validate_op(sketch, sketch_input_info, sketch_weights_info, sketch_biases_info, conv2d_attrs)); ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS); } template diff --git a/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp b/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp index 38c3a0ca0e..d714a2f70c 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,16 +24,15 @@ #ifdef ACL_INTERNAL_TEST_CKW_IN_DF #include "tests/AssetsLibrary.h" #include "tests/CL/CLAccessor.h" -#include "tests/framework/Fixture.h" -#include "tests/framework/Macros.h" -#include "tests/framework/datasets/Datasets.h" #include "tests/datasets/LargeMatMulDataset.h" #include "tests/datasets/SmallMatMulDataset.h" -#include "tests/validation/Validation.h" -#include "tests/validation/reference/Permute.h" -#include "tests/validation/reference/GEMM.h" - +#include "tests/framework/datasets/Datasets.h" +#include "tests/framework/Fixture.h" +#include "tests/framework/Macros.h" #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h" +#include "tests/validation/reference/GEMM.h" +#include "tests/validation/reference/Permute.h" +#include "tests/validation/Validation.h" #include @@ -45,35 +44,37 @@ namespace validation { namespace { - RelativeTolerance tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ -constexpr float abs_tolerance_f32( +RelativeTolerance tolerance_f32( + 0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ +constexpr float abs_tolerance_f32( 0.0001f); /**< Absolute tolerance value for comparing reference's output against implementation's output for floating point data types in case using relative tolerance fails because of small values */ constexpr float abs_tolerance_f16( - 0.001f); /**< Absolute tolerance value for comparing reference's output against implementation's output for fp16 data types in case using relative tolerance fails because of small values */ - RelativeTolerance tolerance_f16(half(0.02)); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ -} + 0.001f); /**< Absolute tolerance value for comparing reference's output against implementation's output for fp16 data types in case using relative tolerance fails because of small values */ +RelativeTolerance tolerance_f16(half( + 0.02)); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ +} // namespace /** M0 values to test --precommit*/ -const auto m0_values_precommit = framework::dataset::make("M0", { 1, 3 }); +const auto m0_values_precommit = framework::dataset::make("M0", {1, 3}); /** N0 values to test --precommit*/ -const auto n0_values_precommit = framework::dataset::make("N0", { 1, 2, 4 }); +const auto n0_values_precommit = framework::dataset::make("N0", {1, 2, 4}); /** K0 values to test --precommit*/ -const auto k0_values_precommit = framework::dataset::make("K0", { 1, 2, 3 }); +const auto k0_values_precommit = framework::dataset::make("K0", {1, 2, 3}); /** M0 values to test --nightly*/ -const auto m0_values_nightly_lhs_nt = framework::dataset::make("M0", { 1, 2, 3, 4, 5, 6, 7, 8 }); -const auto m0_values_nightly_lhs_t = framework::dataset::make("M0", { 1, 2, 3, 4, 8 }); +const auto m0_values_nightly_lhs_nt = framework::dataset::make("M0", {1, 2, 3, 4, 5, 6, 7, 8}); +const auto m0_values_nightly_lhs_t = framework::dataset::make("M0", {1, 2, 3, 4, 8}); /** N0 values to test --nightly*/ -const auto n0_values_nightly_rhs_nt = framework::dataset::make("N0", { 1, 2, 3, 4, 8, 16 }); -const auto n0_values_nightly_rhs_t = framework::dataset::make("N0", { 1, 2, 3, 4, 8 }); +const auto n0_values_nightly_rhs_nt = framework::dataset::make("N0", {1, 2, 3, 4, 8, 16}); +const auto n0_values_nightly_rhs_t = framework::dataset::make("N0", {1, 2, 3, 4, 8}); /** K0 values to test --nightly*/ -const auto k0_values_nightly_lhs_nt_rhs_nt = framework::dataset::make("K0", { 1, 2, 3, 4, 8, 16 }); -const auto k0_values_nightly_rhs_t = framework::dataset::make("K0", { 1, 2, 3, 4, 8 }); -const auto k0_values_nightly_lhs_t_rhs_nt = framework::dataset::make("K0", { 1, 2, 3, 4, 5, 6, 7, 8 }); +const auto k0_values_nightly_lhs_nt_rhs_nt = framework::dataset::make("K0", {1, 2, 3, 4, 8, 16}); +const auto k0_values_nightly_rhs_t = framework::dataset::make("K0", {1, 2, 3, 4, 8}); +const auto k0_values_nightly_lhs_t_rhs_nt = framework::dataset::make("K0", {1, 2, 3, 4, 5, 6, 7, 8}); TEST_SUITE(CL) TEST_SUITE(DYNAMIC_FUSION) @@ -85,45 +86,43 @@ TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL) { using MatMulConfigurationPair = std::pair; - const std::vector supported_block_sizes = - { + const std::vector supported_block_sizes = { // MatMulKernelInfo(adj_lhs, adj_rhs, M0, N0, K0, export_rhs_to_cl_image = false) // Lhs not-transposed, Rhs transposed - { MatMulKernelInfo(false, true, 0, 1, 1), false }, // M0 should be > 0 - { MatMulKernelInfo(false, true, 3, 11, 1), false }, // N0 not in {1, 2, 3, 4, 8, 16} - { MatMulKernelInfo(false, true, 3, 7, 1), false }, // N0 not in {1, 2, 3, 4, 8, 16} - { MatMulKernelInfo(false, true, 3, 3, 12), false }, // K0 not in {1, 2, 3, 4, 8, 16} - { MatMulKernelInfo(false, true, 3, 3, 6), false }, // K0 not in {1, 2, 3, 4, 8, 16} - { MatMulKernelInfo(false, true, 5, 1, 2), true }, - { MatMulKernelInfo(false, true, 3, 3, 3), true }, - { MatMulKernelInfo(false, true, 2, 4, 8), true }, + {MatMulKernelInfo(false, true, 0, 1, 1), false}, // M0 should be > 0 + {MatMulKernelInfo(false, true, 3, 11, 1), false}, // N0 not in {1, 2, 3, 4, 8, 16} + {MatMulKernelInfo(false, true, 3, 7, 1), false}, // N0 not in {1, 2, 3, 4, 8, 16} + {MatMulKernelInfo(false, true, 3, 3, 12), false}, // K0 not in {1, 2, 3, 4, 8, 16} + {MatMulKernelInfo(false, true, 3, 3, 6), false}, // K0 not in {1, 2, 3, 4, 8, 16} + {MatMulKernelInfo(false, true, 5, 1, 2), true}, {MatMulKernelInfo(false, true, 3, 3, 3), true}, + {MatMulKernelInfo(false, true, 2, 4, 8), true}, }; // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; // Set big enough shapes so that block sizes are not truncated. Also, set all dimensions equal // so that it doesn't fail for different NT/T configurations. We aim to test the block sizes here, // not the shapes themselves. - const TensorInfo lhs_info = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32)); - const TensorInfo rhs_info = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32)); + const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32)); + const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32)); - for(auto &pair : supported_block_sizes) + for (auto &pair : supported_block_sizes) { - MatMulAttributes matmul_attr {}; + MatMulAttributes matmul_attr{}; matmul_attr.adj_lhs(pair.first.adj_lhs); matmul_attr.adj_rhs(pair.first.adj_rhs); - GpuMatMulSettings matmul_settings {}; + GpuMatMulSettings matmul_settings{}; matmul_settings.m0(pair.first.m0); matmul_settings.n0(pair.first.n0); matmul_settings.k0(pair.first.k0); - Status status = GpuMatMul::validate_op(sketch, &lhs_info, &rhs_info, matmul_attr, matmul_settings); + Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings); ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS); } } @@ -132,117 +131,110 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL) { // Create a sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; // Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations - using ShapeConfigurationTuple = std::tuple; - const std::vector shape_configurations = - { - { TensorShape(5U, 1U), TensorShape(3U, 5U), true }, - { TensorShape(10U, 12U), TensorShape(3U, 10U), true }, - { TensorShape(8U, 4U), TensorShape(2U, 8U), true }, - { TensorShape(8U, 4U), TensorShape(2U, 5U), false }, // Mismatch in the K dimension - { TensorShape(5U, 0U), TensorShape(2U, 5U), false }, // Invalid dimension - { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), true }, - { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // no batch broadcasting - { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // mismatch in batch dimension + using ShapeConfigurationTuple = std::tuple; + const std::vector shape_configurations = { + {TensorShape(5U, 1U), TensorShape(3U, 5U), true}, + {TensorShape(10U, 12U), TensorShape(3U, 10U), true}, + {TensorShape(8U, 4U), TensorShape(2U, 8U), true}, + {TensorShape(8U, 4U), TensorShape(2U, 5U), false}, // Mismatch in the K dimension + {TensorShape(5U, 0U), TensorShape(2U, 5U), false}, // Invalid dimension + {TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), true}, + {TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false}, // no batch broadcasting + {TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), + false}, // mismatch in batch dimension }; - for(auto &tuple : shape_configurations) + for (auto &tuple : shape_configurations) { const bool expected = std::get<2>(tuple); - for(bool adj_lhs : - { - false - }) + for (bool adj_lhs : {false}) { - for(bool adj_rhs : - { - true - }) + for (bool adj_rhs : {true}) { TensorShape lhs_shape = std::get<0>(tuple); TensorShape rhs_shape = std::get<1>(tuple); - if(adj_lhs) + if (adj_lhs) { permute(lhs_shape, PermutationVector(1U, 0U)); } - if(adj_rhs) + if (adj_rhs) { permute(rhs_shape, PermutationVector(1U, 0U)); } - const TensorInfo lhs_info = context.create_tensor_info(TensorInfo(lhs_shape, 1, DataType::F32)); - const TensorInfo rhs_info = context.create_tensor_info(TensorInfo(rhs_shape, 1, DataType::F32)); + const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(lhs_shape, 1, DataType::F32)); + const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(rhs_shape, 1, DataType::F32)); - MatMulAttributes matmul_attr {}; + MatMulAttributes matmul_attr{}; matmul_attr.adj_lhs(adj_lhs); matmul_attr.adj_rhs(adj_rhs); - GpuMatMulSettings matmul_settings {}; + GpuMatMulSettings matmul_settings{}; matmul_settings.m0(1); matmul_settings.n0(1); matmul_settings.k0(1); - Status status = GpuMatMul::validate_op(sketch, &lhs_info, &rhs_info, matmul_attr, matmul_settings); + Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings); ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS); } } } } - TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL) { // Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations using DataTypeConfigurationTuple = std::tuple; - const std::vector data_type_configurations = - { - { DataType::F32, DataType::F32, DataType::F32, true }, - { DataType::F16, DataType::F16, DataType::F16, true }, - { DataType::F16, DataType::F32, DataType::F32, false }, // no mixed precision - { DataType::F64, DataType::F64, DataType::F64, false }, // no double precision - { DataType::QASYMM8, DataType::QASYMM8, DataType::QASYMM8, false }, // no quantized types - { DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, false }, // no quantized types - { DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, false }, // no quantized types - { DataType::QASYMM16, DataType::QASYMM16, DataType::QASYMM16, false }, // no quantized types - { DataType::QSYMM16, DataType::QSYMM16, DataType::QSYMM16, false }, // no quantized types - { DataType::QSYMM8, DataType::QSYMM8, DataType::QSYMM8, false }, // no quantized types - { DataType::S64, DataType::S64, DataType::S64, false }, // no integral types - { DataType::S32, DataType::S32, DataType::S32, false }, // no integral types - { DataType::S16, DataType::S16, DataType::S16, false }, // no integral types - { DataType::S8, DataType::S8, DataType::S8, false }, // no integral types - { DataType::U64, DataType::U64, DataType::U64, false }, // no integral types - { DataType::U32, DataType::U32, DataType::U32, false }, // no integral types - { DataType::U16, DataType::U16, DataType::U16, false }, // no integral types - { DataType::U8, DataType::U8, DataType::U8, false }, // no integral types + const std::vector data_type_configurations = { + {DataType::F32, DataType::F32, DataType::F32, true}, + {DataType::F16, DataType::F16, DataType::F16, true}, + {DataType::F16, DataType::F32, DataType::F32, false}, // no mixed precision + {DataType::F64, DataType::F64, DataType::F64, false}, // no double precision + {DataType::QASYMM8, DataType::QASYMM8, DataType::QASYMM8, false}, // no quantized types + {DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, false}, // no quantized types + {DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, + false}, // no quantized types + {DataType::QASYMM16, DataType::QASYMM16, DataType::QASYMM16, false}, // no quantized types + {DataType::QSYMM16, DataType::QSYMM16, DataType::QSYMM16, false}, // no quantized types + {DataType::QSYMM8, DataType::QSYMM8, DataType::QSYMM8, false}, // no quantized types + {DataType::S64, DataType::S64, DataType::S64, false}, // no integral types + {DataType::S32, DataType::S32, DataType::S32, false}, // no integral types + {DataType::S16, DataType::S16, DataType::S16, false}, // no integral types + {DataType::S8, DataType::S8, DataType::S8, false}, // no integral types + {DataType::U64, DataType::U64, DataType::U64, false}, // no integral types + {DataType::U32, DataType::U32, DataType::U32, false}, // no integral types + {DataType::U16, DataType::U16, DataType::U16, false}, // no integral types + {DataType::U8, DataType::U8, DataType::U8, false}, // no integral types }; // Create a sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; const TensorShape shape = TensorShape(10U, 10U); - MatMulAttributes matmul_attr {}; + MatMulAttributes matmul_attr{}; matmul_attr.adj_lhs(false); matmul_attr.adj_rhs(false); - GpuMatMulSettings matmul_settings {}; + GpuMatMulSettings matmul_settings{}; matmul_settings.m0(1); matmul_settings.n0(1); matmul_settings.k0(1); - for(auto &tuple : data_type_configurations) + for (auto &tuple : data_type_configurations) { const bool expected = std::get<3>(tuple); - const TensorInfo lhs_info = context.create_tensor_info(TensorInfo(shape, 1, std::get<0>(tuple))); - const TensorInfo rhs_info = context.create_tensor_info(TensorInfo(shape, 1, std::get<1>(tuple))); + const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(shape, 1, std::get<0>(tuple))); + const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(shape, 1, std::get<1>(tuple))); - Status status = GpuMatMul::validate_op(sketch, &lhs_info, &rhs_info, matmul_attr, matmul_settings); + Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings); ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS); } } @@ -250,59 +242,75 @@ TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL) TEST_SUITE_END() // Validate template -using DynamicFusionGpuMatmulFixture = DynamicFusionGpuMatMulValidationFixture; +using DynamicFusionGpuMatmulFixture = DynamicFusionGpuMatMulValidationFixture; TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunTiny, DynamicFusionGpuMatmulFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::TinyMatMulDataset(), - framework::dataset::make("TransposeA", { false })), - framework::dataset::make("TransposeB", { true })), - m0_values_precommit), - n0_values_precommit), - k0_values_precommit), - framework::dataset::make("ExportRhsToCLImage", { false })), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE( + RunTiny, + DynamicFusionGpuMatmulFixture, + framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(datasets::TinyMatMulDataset(), + framework::dataset::make("TransposeA", {false})), + framework::dataset::make("TransposeB", {true})), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportRhsToCLImage", {false})), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuMatmulFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(), - framework::dataset::make("TransposeA", { false })), - framework::dataset::make("TransposeB", { true })), - m0_values_precommit), - n0_values_precommit), - k0_values_precommit), - framework::dataset::make("ExportRhsToCLImage", { false })), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE( + RunSmall, + DynamicFusionGpuMatmulFixture, + framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(), + framework::dataset::make("TransposeA", {false})), + framework::dataset::make("TransposeB", {true})), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportRhsToCLImage", {false})), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLargeRhsTransposed, DynamicFusionGpuMatmulFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), - framework::dataset::make("TransposeA", { false })), - framework::dataset::make("TransposeB", { true })), - m0_values_nightly_lhs_nt), - n0_values_nightly_rhs_t), - k0_values_nightly_rhs_t), - framework::dataset::make("ExportRhsToCLImage", { false })), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE( + RunLargeRhsTransposed, + DynamicFusionGpuMatmulFixture, + framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), + framework::dataset::make("TransposeA", {false})), + framework::dataset::make("TransposeB", {true})), + m0_values_nightly_lhs_nt), + n0_values_nightly_rhs_t), + k0_values_nightly_rhs_t), + framework::dataset::make("ExportRhsToCLImage", {false})), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } // Running High Dimensional test is enough for FP32, because we're stressing the number of dimensions, not data type or M0/N0/K0 -FIXTURE_DATA_TEST_CASE(RunHighDimensional, DynamicFusionGpuMatmulFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::HighDimensionalMatMulDataset(), - framework::dataset::make("TransposeA", { false })), - framework::dataset::make("TransposeB", { true })), - framework::dataset::make("M0", { 2 })), - framework::dataset::make("N0", { 2 })), - framework::dataset::make("K0", { 2 })), - framework::dataset::make("ExportRhsToCLImage", { false })), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE( + RunHighDimensional, + DynamicFusionGpuMatmulFixture, + framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(datasets::HighDimensionalMatMulDataset(), + framework::dataset::make("TransposeA", {false})), + framework::dataset::make("TransposeB", {true})), + framework::dataset::make("M0", {2})), + framework::dataset::make("N0", {2})), + framework::dataset::make("K0", {2})), + framework::dataset::make("ExportRhsToCLImage", {false})), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); @@ -311,28 +319,35 @@ TEST_SUITE_END() // FP32 TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuMatmulFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(), - framework::dataset::make("TransposeA", { false })), - framework::dataset::make("TransposeB", { true })), - m0_values_precommit), - n0_values_precommit), - k0_values_precommit), - framework::dataset::make("ExportRhsToCLImage", { false })), - framework::dataset::make("DataType", DataType::F16))) +FIXTURE_DATA_TEST_CASE( + RunSmall, + DynamicFusionGpuMatmulFixture, + framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(), + framework::dataset::make("TransposeA", {false})), + framework::dataset::make("TransposeB", {true})), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportRhsToCLImage", {false})), + framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); } - -FIXTURE_DATA_TEST_CASE(RunLargeRhsTransposed, DynamicFusionGpuMatmulFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), - framework::dataset::make("TransposeA", { false })), - framework::dataset::make("TransposeB", { true })), - m0_values_nightly_lhs_nt), - n0_values_nightly_rhs_t), - k0_values_nightly_rhs_t), - framework::dataset::make("ExportRhsToCLImage", { false })), - framework::dataset::make("DataType", DataType::F16))) +FIXTURE_DATA_TEST_CASE( + RunLargeRhsTransposed, + DynamicFusionGpuMatmulFixture, + framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), + framework::dataset::make("TransposeA", {false})), + framework::dataset::make("TransposeB", {true})), + m0_values_nightly_lhs_nt), + n0_values_nightly_rhs_t), + k0_values_nightly_rhs_t), + framework::dataset::make("ExportRhsToCLImage", {false})), + framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp index b69479fb7e..c11bffe459 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,14 +29,13 @@ #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h" #include "tests/CL/CLAccessor.h" -#include "tests/framework/Fixture.h" -#include "tests/framework/Macros.h" -#include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" - #include "tests/datasets/DynamicFusionDataset.h" #include "tests/datasets/ShapeDatasets.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/framework/Fixture.h" +#include "tests/framework/Macros.h" #include "tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -58,8 +57,10 @@ namespace validation */ namespace { -constexpr AbsoluteTolerance tolerance_f16(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ -constexpr AbsoluteTolerance tolerance_f32(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ +constexpr AbsoluteTolerance tolerance_f16( + 0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ +constexpr AbsoluteTolerance tolerance_f32( + 0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ } // namespace TEST_SUITE(CL) TEST_SUITE(DYNAMIC_FUSION) @@ -112,7 +113,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( auto lhs_info = context.create_tensor_info(input1_info); auto rhs_info = context.create_tensor_info(input2_info); - bool res = bool(GpuMul::validate_op(sketch, &lhs_info, &rhs_info)); + bool res = bool(GpuMul::validate_op(sketch, lhs_info, rhs_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } // clang-format on @@ -129,9 +130,8 @@ TEST_SUITE(F16) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionCLMulFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", { DataType::F16 })), - framework::dataset::make("InPlace", { false }))) + combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", {DataType::F16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16); @@ -141,8 +141,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, DynamicFusionCLMulBroadcastFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::TemporaryLimitedSmallShapesBroadcast(), - framework::dataset::make("DataType", { DataType::F16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16); @@ -152,8 +152,8 @@ FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp, DynamicFusionCLMulBroadcastFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::TemporaryLimitedLargeShapesBroadcast(), - framework::dataset::make("DataType", { DataType::F16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16); @@ -164,9 +164,8 @@ TEST_SUITE(F32) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionCLMulFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); @@ -175,9 +174,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunLargeOneOp, DynamicFusionCLMulFixture, framework::DatasetMode::NIGHTLY, - combine(combine(datasets::LargeShapes(), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); @@ -187,8 +185,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, DynamicFusionCLMulBroadcastFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::TemporaryLimitedSmallShapesBroadcast(), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); @@ -198,8 +196,8 @@ FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp, DynamicFusionCLMulBroadcastFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::TemporaryLimitedLargeShapesBroadcast(), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); @@ -209,9 +207,9 @@ FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, DynamicFusionCLMulTwoOpsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes(), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false })), - framework::dataset::make("FuseTwoOps", { true }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false})), + framework::dataset::make("FuseTwoOps", {true}))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp index 411e31b32b..f894ce3cf1 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -25,13 +25,13 @@ #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuPool2d.h" #include "tests/CL/CLAccessor.h" -#include "tests/datasets/ShapeDatasets.h" #include "tests/datasets/dynamic_fusion/PoolingLayerDataset.h" +#include "tests/datasets/ShapeDatasets.h" +#include "tests/framework/datasets/Datasets.h" #include "tests/framework/Fixture.h" #include "tests/framework/Macros.h" -#include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -43,15 +43,19 @@ TEST_SUITE(CL) TEST_SUITE(DYNAMIC_FUSION) TEST_SUITE(POOL2D) -constexpr AbsoluteTolerance tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for 32-bit floating-point type */ -constexpr AbsoluteTolerance tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for 16-bit floating-point type */ +constexpr AbsoluteTolerance tolerance_f32( + 0.001f); /**< Tolerance value for comparing reference's output against implementation's output for 32-bit floating-point type */ +constexpr AbsoluteTolerance tolerance_f16( + 0.01f); /**< Tolerance value for comparing reference's output against implementation's output for 16-bit floating-point type */ -const auto PoolingLayerDatasetFP = combine(combine(combine(combine(framework::dataset::make("PoolingType", { PoolingType::MAX, PoolingType::AVG }), framework::dataset::make("PoolingSize", { Size2D(2, 2), Size2D(3, 3) })), - framework::dataset::make("Pad", { Padding2D() })), - framework::dataset::make("Stride", { Size2D(1, 1), Size2D(2, 1), Size2D(5, 7) })), - framework::dataset::make("ExcludePadding", { true })); +const auto PoolingLayerDatasetFP = + combine(combine(combine(combine(framework::dataset::make("PoolingType", {PoolingType::MAX, PoolingType::AVG}), + framework::dataset::make("PoolingSize", {Size2D(2, 2), Size2D(3, 3)})), + framework::dataset::make("Pad", {Padding2D()})), + framework::dataset::make("Stride", {Size2D(1, 1), Size2D(2, 1), Size2D(5, 7)})), + framework::dataset::make("ExcludePadding", {true})); -const auto pool_fp_mixed_precision_dataset = framework::dataset::make("FpMixedPrecision", { true, false }); +const auto pool_fp_mixed_precision_dataset = framework::dataset::make("FpMixedPrecision", {true, false}); template using DynamicFusionGpuPool2dFixture = DynamicFusionGpuPool2dValidationFixture; @@ -60,7 +64,8 @@ template using DFSpecialGpuPool2dFixture = DynamicFusionGpuPool2dSpecialValidationFixture; template -using DFPoolMixedPrecisionFixture = DynamicFusionGpuPool2dMixedPrecisionValidationFixture; +using DFPoolMixedPrecisionFixture = + DynamicFusionGpuPool2dMixedPrecisionValidationFixture; // *INDENT-OFF* // clang-format off @@ -91,7 +96,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( // Validate Pool2d Configuration auto src_info = context.create_tensor_info(input_info); - bool res = bool(GpuPool2d::validate_op(sketch, &src_info, pool2d_attr, settings)); + bool res = bool(GpuPool2d::validate_op(sketch, src_info, pool2d_attr, settings)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } @@ -100,53 +105,68 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuPool2dFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuPool2dFixture, + framework::DatasetMode::PRECOMMIT, + combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuPool2dFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunLarge, + DynamicFusionGpuPool2dFixture, + framework::DatasetMode::NIGHTLY, + combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunSpecial, DFSpecialGpuPool2dFixture, framework::DatasetMode::ALL, combine(datasets::PoolingLayerDatasetSpecialDynamicFusion(), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunSpecial, + DFSpecialGpuPool2dFixture, + framework::DatasetMode::ALL, + combine(datasets::PoolingLayerDatasetSpecialDynamicFusion(), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); } TEST_SUITE(GlobalPooling) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuPool2dFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine( - framework::dataset::make("InputShape", { TensorShape(27U, 13U, 2U), - TensorShape(27U, 13U, 2U, 4U) - }), - framework::dataset::make("PoolingType", { PoolingType::AVG, PoolingType::MAX })), - framework::dataset::make("PoolingSize", { Size2D(27, 13) })), - framework::dataset::make("Pad", { Padding2D() })), - framework::dataset::make("Stride", { Size2D(1, 1) })), - framework::dataset::make("ExcludePadding", true)), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE( + RunSmall, + DynamicFusionGpuPool2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape", + {TensorShape(27U, 13U, 2U), + TensorShape(27U, 13U, 2U, 4U)}), + framework::dataset::make("PoolingType", + {PoolingType::AVG, PoolingType::MAX})), + framework::dataset::make("PoolingSize", {Size2D(27, 13)})), + framework::dataset::make("Pad", {Padding2D()})), + framework::dataset::make("Stride", {Size2D(1, 1)})), + framework::dataset::make("ExcludePadding", true)), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuPool2dFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine( - framework::dataset::make("InputShape", { TensorShape(79U, 37U, 11U), - TensorShape(79U, 37U, 11U, 4U) - }), - framework::dataset::make("PoolingType", { PoolingType::AVG, PoolingType::MAX })), - framework::dataset::make("PoolingSize", { Size2D(79, 37) })), - framework::dataset::make("Pad", { Padding2D() })), - framework::dataset::make("Stride", { Size2D(1, 1) })), - framework::dataset::make("ExcludePadding", true)), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE( + RunLarge, + DynamicFusionGpuPool2dFixture, + framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape", + {TensorShape(79U, 37U, 11U), + TensorShape(79U, 37U, 11U, 4U)}), + framework::dataset::make("PoolingType", + {PoolingType::AVG, PoolingType::MAX})), + framework::dataset::make("PoolingSize", {Size2D(79, 37)})), + framework::dataset::make("Pad", {Padding2D()})), + framework::dataset::make("Stride", {Size2D(1, 1)})), + framework::dataset::make("ExcludePadding", true)), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); @@ -155,49 +175,61 @@ TEST_SUITE_END() // GlobalPooling TEST_SUITE_END() // FP32 TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, DFPoolMixedPrecisionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP), - framework::dataset::make("DataType", DataType::F16)), - pool_fp_mixed_precision_dataset)) +FIXTURE_DATA_TEST_CASE(RunSmall, + DFPoolMixedPrecisionFixture, + framework::DatasetMode::PRECOMMIT, + combine(combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP), + framework::dataset::make("DataType", DataType::F16)), + pool_fp_mixed_precision_dataset)) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLarge, DFPoolMixedPrecisionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP), - framework::dataset::make("DataType", DataType::F16)), - pool_fp_mixed_precision_dataset)) +FIXTURE_DATA_TEST_CASE(RunLarge, + DFPoolMixedPrecisionFixture, + framework::DatasetMode::NIGHTLY, + combine(combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP), + framework::dataset::make("DataType", DataType::F16)), + pool_fp_mixed_precision_dataset)) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16); } TEST_SUITE(GlobalPooling) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuPool2dFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine( - framework::dataset::make("InputShape", { TensorShape(27U, 13U, 2U), - TensorShape(27U, 13U, 2U, 4U) - }), - framework::dataset::make("PoolingType", { PoolingType::AVG, PoolingType::MAX })), - framework::dataset::make("PoolingSize", { Size2D(27, 13) })), - framework::dataset::make("Pad", { Padding2D() })), - framework::dataset::make("Stride", { Size2D(1, 1) })), - framework::dataset::make("ExcludePadding", true)), - framework::dataset::make("DataType", DataType::F16))) +FIXTURE_DATA_TEST_CASE( + RunSmall, + DynamicFusionGpuPool2dFixture, + framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape", + {TensorShape(27U, 13U, 2U), + TensorShape(27U, 13U, 2U, 4U)}), + framework::dataset::make("PoolingType", + {PoolingType::AVG, PoolingType::MAX})), + framework::dataset::make("PoolingSize", {Size2D(27, 13)})), + framework::dataset::make("Pad", {Padding2D()})), + framework::dataset::make("Stride", {Size2D(1, 1)})), + framework::dataset::make("ExcludePadding", true)), + framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuPool2dFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine( - framework::dataset::make("InputShape", { TensorShape(79U, 37U, 11U), - TensorShape(79U, 37U, 11U, 4U) - }), - framework::dataset::make("PoolingType", { PoolingType::AVG, PoolingType::MAX })), - framework::dataset::make("PoolingSize", { Size2D(79, 37) })), - framework::dataset::make("Pad", { Padding2D() })), - framework::dataset::make("Stride", { Size2D(1, 1) })), - framework::dataset::make("ExcludePadding", true)), - framework::dataset::make("DataType", DataType::F16))) +FIXTURE_DATA_TEST_CASE( + RunLarge, + DynamicFusionGpuPool2dFixture, + framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape", + {TensorShape(79U, 37U, 11U), + TensorShape(79U, 37U, 11U, 4U)}), + framework::dataset::make("PoolingType", + {PoolingType::AVG, PoolingType::MAX})), + framework::dataset::make("PoolingSize", {Size2D(79, 37)})), + framework::dataset::make("Pad", {Padding2D()})), + framework::dataset::make("Stride", {Size2D(1, 1)})), + framework::dataset::make("ExcludePadding", true)), + framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16); @@ -209,7 +241,7 @@ TEST_SUITE_END() // FLOAT TEST_SUITE_END() // POOL2D TEST_SUITE_END() // DYNAMIC_FUSION TEST_SUITE_END() // CL -} -} -} +} // namespace validation +} // namespace test +} // namespace arm_compute #endif // ACL_INTERNAL_TEST_CKW_IN_DF diff --git a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp index 4d038b2780..43617fe1be 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,10 +24,10 @@ #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test if ACL_INTERNAL_TEST_CKW_IN_DF and the op has not been ported to ckw #include "tests/CL/CLAccessor.h" #include "tests/datasets/ReshapeLayerDataset.h" -#include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" +#include "tests/framework/Macros.h" #include "tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -39,41 +39,52 @@ TEST_SUITE(CL) TEST_SUITE(DYNAMIC_FUSION) TEST_SUITE(RESHAPE) -DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(framework::dataset::make("InputInfo", -{ - TensorInfo(TensorShape(9U, 5U, 7U, 3U), 1, DataType::F32), TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32), TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32) /*mismatching dimensions*/, -}), -framework::dataset::make("OutputShape", -{ - TensorShape(9U, 5U, 21U), - TensorShape(8U, 24U, 4U), - TensorShape(192U, 192U), -})), -framework::dataset::make("Expected", { true, true, false })), -input_info, output_shape, expected) +DATA_TEST_CASE(Validate, + framework::DatasetMode::ALL, + zip(zip(framework::dataset::make( + "InputInfo", + { + TensorInfo(TensorShape(9U, 5U, 7U, 3U), 1, DataType::F32), + TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32), + TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32) /*mismatching dimensions*/, + }), + framework::dataset::make("OutputShape", + { + TensorShape(9U, 5U, 21U), + TensorShape(8U, 24U, 4U), + TensorShape(192U, 192U), + })), + framework::dataset::make("Expected", {true, true, false})), + input_info, + output_shape, + expected) { // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; // Create sketch tensors TensorShape input_shape = input_info.tensor_shape(); ARM_COMPUTE_UNUSED(input_shape); - TensorInfo src_info = context.create_tensor_info(input_info); + ITensorInfo *src_info = context.create_tensor_info(input_info); ReshapeAttributes attributes; attributes.shape(output_shape); - Status status = GpuReshape::validate_op(sketch, &src_info, attributes); + Status status = GpuReshape::validate_op(sketch, src_info, attributes); ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS); } template -using DynamicFusionGpuReshapeLayerFixture = DynamicFusionGpuReshapeLayerValidationFixture; +using DynamicFusionGpuReshapeLayerFixture = + DynamicFusionGpuReshapeLayerValidationFixture; TEST_SUITE(F32) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType", - DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuReshapeLayerFixture, + framework::DatasetMode::ALL, + combine(datasets::SmallReshapeLayerDataset(), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference); @@ -81,8 +92,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, fra TEST_SUITE_END() // F32 TEST_SUITE(F16) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType", - DataType::F16))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuReshapeLayerFixture, + framework::DatasetMode::ALL, + combine(datasets::SmallReshapeLayerDataset(), + framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference); @@ -90,8 +104,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, fram TEST_SUITE_END() // F16 TEST_SUITE(U8) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType", - DataType::U8))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuReshapeLayerFixture, + framework::DatasetMode::ALL, + combine(datasets::SmallReshapeLayerDataset(), + framework::dataset::make("DataType", DataType::U8))) { // Validate output validate(CLAccessor(_target), _reference); @@ -99,8 +116,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, f TEST_SUITE_END() // U8 TEST_SUITE(S8) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType", - DataType::S8))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuReshapeLayerFixture, + framework::DatasetMode::ALL, + combine(datasets::SmallReshapeLayerDataset(), + framework::dataset::make("DataType", DataType::S8))) { // Validate output validate(CLAccessor(_target), _reference); @@ -108,8 +128,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, fr TEST_SUITE_END() // S8 TEST_SUITE(S16) -FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType", - DataType::S16))) +FIXTURE_DATA_TEST_CASE(RunSmall, + DynamicFusionGpuReshapeLayerFixture, + framework::DatasetMode::ALL, + combine(datasets::SmallReshapeLayerDataset(), + framework::dataset::make("DataType", DataType::S16))) { // Validate output validate(CLAccessor(_target), _reference); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp index 5f99cd6d78..10915acfaa 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp @@ -1,5 +1,5 @@ /* -* Copyright (c) 2022-2023 Arm Limited. +* Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,8 +29,8 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" #include "tests/framework/Macros.h" -#include "tests/validation/Validation.h" #include "tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h" +#include "tests/validation/Validation.h" using namespace arm_compute::experimental::dynamic_fusion; namespace arm_compute @@ -41,10 +41,10 @@ namespace validation { namespace { -using datasets::ScaleShapesBaseDataSet; +using datasets::ScaleAlignCornersSamplingPolicySet; using datasets::ScaleInterpolationPolicySet; using datasets::ScaleSamplingPolicySet; -using datasets::ScaleAlignCornersSamplingPolicySet; +using datasets::ScaleShapesBaseDataSet; /** We consider vector size in byte 16 since the maximum size of * a vector used by @ref CLScaleKernel is currently 16-byte (float4). @@ -59,9 +59,9 @@ constexpr uint32_t num_elements_per_vector() /** Quantization information data set */ const auto QuantizationInfoSet = framework::dataset::make("QuantizationInfo", -{ - QuantizationInfo(0.5f, -1), -}); + { + QuantizationInfo(0.5f, -1), + }); /** Tolerance */ constexpr AbsoluteTolerance tolerance_q8(1); @@ -83,22 +83,20 @@ TEST_SUITE(RESIZE) TEST_SUITE(Validate) -const auto default_input_shape = TensorShape{ 2, 3, 3, 2 }; -const auto default_output_shape = TensorShape{ 4, 6, 3, 2 }; +const auto default_input_shape = TensorShape{2, 3, 3, 2}; +const auto default_output_shape = TensorShape{4, 6, 3, 2}; constexpr auto default_data_type = DataType::U8; constexpr auto default_data_layout = DataLayout::NHWC; TEST_CASE(NullPtr, framework::DatasetMode::ALL) { - const TensorInfo input_info = TensorInfo{ default_input_shape, 1, default_data_type, default_data_layout }; - const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, default_data_layout }; + const TensorInfo input_info = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout}; + const TensorInfo output_info = TensorInfo{default_output_shape, 1, default_data_type, default_data_layout}; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; - - const TensorInfo sketch_input_info = context.create_tensor_info(input_info); + GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; // nullptr is given as input Status status = GpuResize::validate_op(sketch, nullptr, ResizeAttributes()); @@ -107,44 +105,43 @@ TEST_CASE(NullPtr, framework::DatasetMode::ALL) TEST_CASE(SupportDataType, framework::DatasetMode::ALL) { - const std::map supported_data_types = - { - { DataType::U8, true }, - { DataType::S8, false }, - { DataType::QSYMM8, false }, - { DataType::QASYMM8, true }, - { DataType::QASYMM8_SIGNED, true }, - { DataType::QSYMM8_PER_CHANNEL, false }, - { DataType::U16, false }, - { DataType::S16, true }, - { DataType::QSYMM16, false }, - { DataType::QASYMM16, false }, - { DataType::U32, false }, - { DataType::S32, false }, - { DataType::U64, false }, - { DataType::S64, false }, - { DataType::BFLOAT16, false }, - { DataType::F16, true }, - { DataType::F32, true }, - { DataType::F64, false }, - { DataType::SIZET, false }, + const std::map supported_data_types = { + {DataType::U8, true}, + {DataType::S8, false}, + {DataType::QSYMM8, false}, + {DataType::QASYMM8, true}, + {DataType::QASYMM8_SIGNED, true}, + {DataType::QSYMM8_PER_CHANNEL, false}, + {DataType::U16, false}, + {DataType::S16, true}, + {DataType::QSYMM16, false}, + {DataType::QASYMM16, false}, + {DataType::U32, false}, + {DataType::S32, false}, + {DataType::U64, false}, + {DataType::S64, false}, + {DataType::BFLOAT16, false}, + {DataType::F16, true}, + {DataType::F32, true}, + {DataType::F64, false}, + {DataType::SIZET, false}, }; - for(auto &kv : supported_data_types) + for (auto &kv : supported_data_types) { - const TensorInfo input_info = TensorInfo{ default_input_shape, 1, kv.first, default_data_layout }; + const TensorInfo input_info = TensorInfo{default_input_shape, 1, kv.first, default_data_layout}; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; - const TensorInfo sketch_input_info = context.create_tensor_info(input_info); + const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info); ResizeAttributes attributes; attributes.output_width(default_output_shape[0]); // shape is not important unless it's empty attributes.output_height(default_output_shape[1]); - Status status = GpuResize::validate_op(sketch, &sketch_input_info, attributes); + Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes); ARM_COMPUTE_EXPECT(bool(status) == kv.second, framework::LogLevel::ERRORS); } } @@ -153,16 +150,16 @@ TEST_CASE(MismatchingDataType, framework::DatasetMode::ALL) { constexpr DataType non_default_data_type = DataType::F32; - const TensorInfo input_info = TensorInfo{ default_input_shape, 1, default_data_type, default_data_layout }; - const TensorInfo output_info = TensorInfo{ default_output_shape, 1, non_default_data_type, default_data_layout }; + const TensorInfo input_info = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout}; + const TensorInfo output_info = TensorInfo{default_output_shape, 1, non_default_data_type, default_data_layout}; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; - const TensorInfo sketch_input_info = context.create_tensor_info(input_info); + const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info); - Status status = GpuResize::validate_op(sketch, &sketch_input_info, ResizeAttributes()); + Status status = GpuResize::validate_op(sketch, sketch_input_info, ResizeAttributes()); ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS); } @@ -173,59 +170,57 @@ TEST_CASE(AlignedCornerNotSupported, framework::DatasetMode::ALL) constexpr bool align_corners = true; constexpr SamplingPolicy sampling_policy = SamplingPolicy::CENTER; - const TensorInfo input_info = TensorInfo{ default_input_shape, 1, default_data_type, default_data_layout }; - const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, default_data_layout }; + const TensorInfo input_info = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout}; + const TensorInfo output_info = TensorInfo{default_output_shape, 1, default_data_type, default_data_layout}; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; - const TensorInfo sketch_input_info = context.create_tensor_info(input_info); + const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info); ResizeAttributes attributes{}; - attributes.interpolation_policy(interpolation_policy) - .sampling_policy(sampling_policy) - .align_corners(align_corners); + attributes.interpolation_policy(interpolation_policy).sampling_policy(sampling_policy).align_corners(align_corners); - Status status = GpuResize::validate_op(sketch, &sketch_input_info, attributes); + Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes); ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS); } TEST_CASE(UnsupportedInterpolationPolicy, framework::DatasetMode::ALL) { - const TensorInfo input_info = TensorInfo{ TensorShape(28U, 33U, 2U), 1, DataType::F32, default_data_layout }; - const TensorInfo output_info = TensorInfo{ TensorShape(26U, 21U, 2U), 1, DataType::F32, default_data_layout }; + const TensorInfo input_info = TensorInfo{TensorShape(28U, 33U, 2U), 1, DataType::F32, default_data_layout}; + const TensorInfo output_info = TensorInfo{TensorShape(26U, 21U, 2U), 1, DataType::F32, default_data_layout}; constexpr auto interpolation_policy = InterpolationPolicy::AREA; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; - const TensorInfo sketch_input_info = context.create_tensor_info(input_info); + const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info); ResizeAttributes attributes{}; attributes.interpolation_policy(interpolation_policy); - Status status = GpuResize::validate_op(sketch, &sketch_input_info, attributes); + Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes); ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS); } TEST_CASE(UnsupportedLayout, framework::DatasetMode::ALL) { - const TensorInfo input_info = TensorInfo{ default_input_shape, 1, default_data_type, DataLayout::NCHW }; - const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, DataLayout::NCHW }; + const TensorInfo input_info = TensorInfo{default_input_shape, 1, default_data_type, DataLayout::NCHW}; + const TensorInfo output_info = TensorInfo{default_output_shape, 1, default_data_type, DataLayout::NCHW}; constexpr auto interpolation_policy = InterpolationPolicy::BILINEAR; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + GpuWorkloadContext context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; - const TensorInfo sketch_input_info = context.create_tensor_info(input_info); + const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info); ResizeAttributes attributes{}; attributes.interpolation_policy(interpolation_policy); - Status status = GpuResize::validate_op(sketch, &sketch_input_info, attributes); + Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes); ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS); } @@ -237,43 +232,60 @@ using DynamicFusionResizeFixture = DynamicFusionResizeValidationFixture())), framework::dataset::make("DataType", DataType::F32)); +const auto f32_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::F32)); -FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeFixture, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(Run, + DynamicFusionResizeFixture, + framework::DatasetMode::ALL, + ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute); } -FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeFixture, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleAlignCornersSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(RunAlignCorners, + DynamicFusionResizeFixture, + framework::DatasetMode::ALL, + ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleAlignCornersSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute); } -const auto f32_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::F32)); -FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, ScaleSamplingPolicySet)) +const auto f32_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::F32)); +FIXTURE_DATA_TEST_CASE(RunNightly, + DynamicFusionResizeFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, ScaleSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute); } -FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, - ScaleAlignCornersSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, + DynamicFusionResizeFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, ScaleAlignCornersSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute); @@ -281,41 +293,58 @@ FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture TEST_SUITE_END() // FP32 TEST_SUITE(FP16) -const auto f16_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::F16)); -FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeFixture, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleSamplingPolicySet)) +const auto f16_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::F16)); +FIXTURE_DATA_TEST_CASE(Run, + DynamicFusionResizeFixture, + framework::DatasetMode::ALL, + ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeFixture, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleAlignCornersSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(RunAlignCorners, + DynamicFusionResizeFixture, + framework::DatasetMode::ALL, + ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleAlignCornersSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16); } -const auto f16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::F16)); -FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, ScaleSamplingPolicySet)) +const auto f16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::F16)); +FIXTURE_DATA_TEST_CASE(RunNightly, + DynamicFusionResizeFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, ScaleSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, - ScaleAlignCornersSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, + DynamicFusionResizeFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, ScaleAlignCornersSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16); @@ -325,41 +354,58 @@ TEST_SUITE_END() // Float TEST_SUITE(Integer) TEST_SUITE(U8) -const auto u8_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::U8)); -FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeFixture, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_shape, ScaleSamplingPolicySet)) +const auto u8_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::U8)); +FIXTURE_DATA_TEST_CASE(Run, + DynamicFusionResizeFixture, + framework::DatasetMode::ALL, + ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_shape, ScaleSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_q8); } -FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeFixture, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_shape, ScaleAlignCornersSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(RunAlignCorners, + DynamicFusionResizeFixture, + framework::DatasetMode::ALL, + ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_shape, ScaleAlignCornersSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_q8); } -const auto u8_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::U8)); -FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_nightly_shape, ScaleSamplingPolicySet)) +const auto u8_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::U8)); +FIXTURE_DATA_TEST_CASE(RunNightly, + DynamicFusionResizeFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_nightly_shape, ScaleSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_q8); } -FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_nightly_shape, - ScaleAlignCornersSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, + DynamicFusionResizeFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_nightly_shape, ScaleAlignCornersSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_q8); @@ -367,41 +413,58 @@ FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture())), framework::dataset::make("DataType", DataType::S16)); -FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeFixture, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_shape, ScaleSamplingPolicySet)) +const auto s16_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::S16)); +FIXTURE_DATA_TEST_CASE(Run, + DynamicFusionResizeFixture, + framework::DatasetMode::ALL, + ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_shape, ScaleSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_s16); } -FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeFixture, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_shape, ScaleAlignCornersSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(RunAlignCorners, + DynamicFusionResizeFixture, + framework::DatasetMode::ALL, + ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_shape, ScaleAlignCornersSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_s16); } -const auto s16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::S16)); -FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_nightly_shape, ScaleSamplingPolicySet)) +const auto s16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::S16)); +FIXTURE_DATA_TEST_CASE(RunNightly, + DynamicFusionResizeFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_nightly_shape, ScaleSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_s16); } -FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_nightly_shape, - ScaleAlignCornersSamplingPolicySet)) +FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, + DynamicFusionResizeFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_nightly_shape, ScaleAlignCornersSamplingPolicySet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_s16); @@ -410,50 +473,70 @@ TEST_SUITE_END() // S16 TEST_SUITE_END() // Integer template -using DynamicFusionResizeQuantizedFixture = DynamicFusionResizeQuantizedValidationFixture; +using DynamicFusionResizeQuantizedFixture = + DynamicFusionResizeQuantizedValidationFixture; TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) -const auto qasymm8_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::QASYMM8)); -FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeQuantizedFixture, framework::DatasetMode::ALL, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_shape, ScaleSamplingPolicySet, - QuantizationInfoSet)) +const auto qasymm8_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::QASYMM8)); +FIXTURE_DATA_TEST_CASE(Run, + DynamicFusionResizeQuantizedFixture, + framework::DatasetMode::ALL, + ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_shape, + ScaleSamplingPolicySet, + QuantizationInfoSet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_q8); } -FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeQuantizedFixture, framework::DatasetMode::ALL, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_shape, - ScaleAlignCornersSamplingPolicySet, - QuantizationInfoSet)) +FIXTURE_DATA_TEST_CASE(RunAlignCorners, + DynamicFusionResizeQuantizedFixture, + framework::DatasetMode::ALL, + ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_shape, + ScaleAlignCornersSamplingPolicySet, + QuantizationInfoSet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_q8); } -const auto qasymm8_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::QASYMM8)); -FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeQuantizedFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_nightly_shape, - ScaleSamplingPolicySet, - QuantizationInfoSet)) +const auto qasymm8_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::QASYMM8)); +FIXTURE_DATA_TEST_CASE(RunNightly, + DynamicFusionResizeQuantizedFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_nightly_shape, + ScaleSamplingPolicySet, + QuantizationInfoSet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_q8); } -FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeQuantizedFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_nightly_shape, - ScaleAlignCornersSamplingPolicySet, - QuantizationInfoSet)) +FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, + DynamicFusionResizeQuantizedFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_nightly_shape, + ScaleAlignCornersSamplingPolicySet, + QuantizationInfoSet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_q8); @@ -461,47 +544,66 @@ FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeQuantizedFixtu TEST_SUITE_END() // QASYMM8 TEST_SUITE(QASYMM8_SIGNED) -const auto qasymm8_signed_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)); -FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeQuantizedFixture, framework::DatasetMode::ALL, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_shape, ScaleSamplingPolicySet, - QuantizationInfoSet)) +const auto qasymm8_signed_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)); +FIXTURE_DATA_TEST_CASE(Run, + DynamicFusionResizeQuantizedFixture, + framework::DatasetMode::ALL, + ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_shape, + ScaleSamplingPolicySet, + QuantizationInfoSet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_qs8); } -FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeQuantizedFixture, framework::DatasetMode::ALL, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_shape, - ScaleAlignCornersSamplingPolicySet, - QuantizationInfoSet)) +FIXTURE_DATA_TEST_CASE(RunAlignCorners, + DynamicFusionResizeQuantizedFixture, + framework::DatasetMode::ALL, + ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_shape, + ScaleAlignCornersSamplingPolicySet, + QuantizationInfoSet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_qs8); } -const auto qasymm8_signed_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)); -FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeQuantizedFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_nightly_shape, - ScaleSamplingPolicySet, - QuantizationInfoSet)) +const auto qasymm8_signed_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector())), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)); +FIXTURE_DATA_TEST_CASE(RunNightly, + DynamicFusionResizeQuantizedFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_nightly_shape, + ScaleSamplingPolicySet, + QuantizationInfoSet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_qs8); } -FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeQuantizedFixture, framework::DatasetMode::NIGHTLY, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_nightly_shape, - ScaleAlignCornersSamplingPolicySet, - QuantizationInfoSet)) +FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, + DynamicFusionResizeQuantizedFixture, + framework::DatasetMode::NIGHTLY, + ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_nightly_shape, + ScaleAlignCornersSamplingPolicySet, + QuantizationInfoSet)) { //Create valid region TensorInfo src_info(_shape, 1, _data_type); - const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); + const ValidRegion valid_region = + calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false); // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_qs8); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp index e995511171..0134a7c11b 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,10 +29,10 @@ #include "tests/CL/CLAccessor.h" #include "tests/datasets/ShapeDatasets.h" #include "tests/framework/Asserts.h" -#include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" +#include "tests/framework/Macros.h" #include "tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -65,9 +65,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip( GpuWorkloadSketch sketch{ &context }; // Fuse sigmoid - const TensorInfo src_info = context.create_tensor_info(input_info); + const ITensorInfo *src_info = context.create_tensor_info(input_info); - const bool res = static_cast(GpuSigmoid::validate_op(sketch, &src_info)); + const bool res = static_cast(GpuSigmoid::validate_op(sketch, src_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } // clang-format on @@ -81,8 +81,7 @@ TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionSigmoidOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("Fuse", { false })), + combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -92,8 +91,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, DynamicFusionSigmoidOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::Small5dShapes(), - framework::dataset::make("Fuse", { false })), + combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -104,8 +102,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, DynamicFusionSigmoidOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("Fuse", { true })), + combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -118,8 +115,7 @@ TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionSigmoidOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("Fuse", { false })), + combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F32))) { // Validate output @@ -129,8 +125,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, DynamicFusionSigmoidOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::Small5dShapes(), - framework::dataset::make("Fuse", { false })), + combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F32))) { // Validate output @@ -141,8 +136,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, DynamicFusionSigmoidOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("Fuse", { true })), + combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})), framework::dataset::make("DataType", DataType::F32))) { // Validate output diff --git a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp index 340f5dc2a3..b7cb6bace6 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,11 +28,11 @@ #include "tests/CL/CLAccessor.h" #include "tests/datasets/ShapeDatasets.h" #include "tests/framework/Asserts.h" +#include "tests/framework/datasets/Datasets.h" #include "tests/framework/Fixture.h" #include "tests/framework/Macros.h" -#include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" #include "tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h" +#include "tests/validation/Validation.h" using namespace arm_compute::experimental::dynamic_fusion; @@ -110,9 +110,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( SoftmaxAttributes softmax_attr{}; softmax_attr.axis(axis).beta(beta).is_log_softmax(false); - TensorInfo src_info = context.create_tensor_info(input_info); - TensorInfo dst_info = context.create_tensor_info(output_info); - const bool res = static_cast(GpuSoftmax::validate_op(sketch, &src_info, &dst_info, softmax_attr)); + ITensorInfo* src_info = context.create_tensor_info(input_info); + ITensorInfo* dst_info = context.create_tensor_info(output_info); + const bool res = static_cast(GpuSoftmax::validate_op(sketch, src_info, dst_info, softmax_attr)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp index 022c9b46a8..ef9f75b1c0 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,14 +29,13 @@ #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSub.h" #include "tests/CL/CLAccessor.h" -#include "tests/framework/Fixture.h" -#include "tests/framework/Macros.h" -#include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" - #include "tests/datasets/DynamicFusionDataset.h" #include "tests/datasets/ShapeDatasets.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/framework/Fixture.h" +#include "tests/framework/Macros.h" #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -99,29 +98,32 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( auto lhs_info = context.create_tensor_info(input1_info); auto rhs_info = context.create_tensor_info(input2_info); - bool res = bool(GpuSub::validate_op(sketch, &lhs_info, &rhs_info)); + bool res = bool(GpuSub::validate_op(sketch, lhs_info, rhs_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } // clang-format on // *INDENT-ON* template -using DynamicFusionCLSubFixture = DynamicFusionGpuElementwiseBinaryOneOpValidationFixture; +using DynamicFusionCLSubFixture = + DynamicFusionGpuElementwiseBinaryOneOpValidationFixture; template -using DynamicFusionCLSubBroadcastFixture = DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture; +using DynamicFusionCLSubBroadcastFixture = + DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture; template -using DynamicFusionCLSubTwoOpsFixture = DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture; +using DynamicFusionCLSubTwoOpsFixture = + DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture; TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionCLSubFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -129,10 +131,10 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunLargeOneOp, DynamicFusionCLSubFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::LargeShapes()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -140,10 +142,10 @@ FIXTURE_DATA_TEST_CASE(RunLargeOneOp, FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, DynamicFusionCLSubBroadcastFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::TemporaryLimitedSmallShapesBroadcast()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -152,22 +154,23 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp, DynamicFusionCLSubBroadcastFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::TemporaryLimitedLargeShapesBroadcast()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, - DynamicFusionCLSubTwoOpsFixture, - framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), - datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("InPlace", { false })), - framework::dataset::make("FuseTwoOps", { true }))) +FIXTURE_DATA_TEST_CASE( + RunSmallTwoOps, + DynamicFusionCLSubTwoOpsFixture, + framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), + datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()), + framework::dataset::make("DataType", {DataType::F32})), + framework::dataset::make("InPlace", {false})), + framework::dataset::make("FuseTwoOps", {true}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -178,10 +181,10 @@ TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionCLSubFixture, framework::DatasetMode::ALL, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::F16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -190,10 +193,10 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, DynamicFusionCLSubBroadcastFixture, framework::DatasetMode::ALL, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::TemporaryLimitedSmallShapesBroadcast()), - framework::dataset::make("DataType", { DataType::F16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::F16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -205,10 +208,10 @@ TEST_SUITE(S32) FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionCLSubFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::S32 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::S32})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -219,10 +222,10 @@ TEST_SUITE(S16) FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionCLSubFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::S16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::S16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -230,10 +233,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionCLSubFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::LargeShapes()), - framework::dataset::make("DataType", { DataType::S16 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::S16})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); @@ -244,10 +247,10 @@ TEST_SUITE(U8) FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionCLSubFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }), + combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}), datasets::SmallShapes()), - framework::dataset::make("DataType", { DataType::U8 })), - framework::dataset::make("InPlace", { false }))) + framework::dataset::make("DataType", {DataType::U8})), + framework::dataset::make("InPlace", {false}))) { // Validate output validate(CLAccessor(_target), _reference); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp index 12f3677abf..2560f3aab1 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,10 +29,10 @@ #include "tests/CL/CLAccessor.h" #include "tests/datasets/ShapeDatasets.h" #include "tests/framework/Asserts.h" -#include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" -#include "tests/validation/Validation.h" +#include "tests/framework/Macros.h" #include "tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -65,9 +65,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip( GpuWorkloadSketch sketch{ &context }; // Fuse tanh - const TensorInfo src_info = context.create_tensor_info(input_info); + const ITensorInfo* src_info = context.create_tensor_info(input_info); - const bool res = static_cast(GpuTanh::validate_op(sketch, &src_info)); + const bool res = static_cast(GpuTanh::validate_op(sketch, src_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } // clang-format on @@ -81,8 +81,7 @@ TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionTanhOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("Fuse", { false })), + combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -92,8 +91,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, DynamicFusionTanhOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::Small5dShapes(), - framework::dataset::make("Fuse", { false })), + combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -104,8 +102,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, DynamicFusionTanhOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("Fuse", { true })), + combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})), framework::dataset::make("DataType", DataType::F16))) { // Validate output @@ -118,8 +115,7 @@ TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionTanhOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("Fuse", { false })), + combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F32))) { // Validate output @@ -129,8 +125,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallOneOp, FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, DynamicFusionTanhOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::Small5dShapes(), - framework::dataset::make("Fuse", { false })), + combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})), framework::dataset::make("DataType", DataType::F32))) { // Validate output @@ -141,8 +136,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp, FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, DynamicFusionTanhOpFixture, framework::DatasetMode::ALL, - combine(combine(datasets::SmallShapes(), - framework::dataset::make("Fuse", { true })), + combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})), framework::dataset::make("DataType", DataType::F32))) { // Validate output -- cgit v1.2.1