From 3fcf3dcf7b6ffc613468ccaca580bde495677440 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Wed, 17 May 2023 15:17:48 +0100 Subject: Add multi-sketch support for dynamic fusion * Tensors are owned by workload context instead of workload sketch so that they can be used by multiple sketches. * Add an integration test for multi-sketch case. Resolves: COMPMID-6148 Signed-off-by: Viet-Hoa Do Change-Id: I37d0de5ac103fb2a85020aa1c26e49eb304f47b7 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9706 Comments-Addressed: Arm Jenkins Reviewed-by: SiCong Li Tested-by: Arm Jenkins Benchmark: Arm Jenkins --- .../validation/dynamic_fusion/gpu/Integration.cpp | 262 +++++++++++++++++++-- tests/validation/dynamic_fusion/gpu/cl/Add.cpp | 8 +- tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp | 6 +- .../dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp | 10 +- .../dynamic_fusion/gpu/cl/DirectConv2d.cpp | 10 +- tests/validation/dynamic_fusion/gpu/cl/Mul.cpp | 8 +- tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp | 8 +- tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp | 6 +- tests/validation/dynamic_fusion/gpu/cl/Resize.cpp | 36 +-- tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp | 6 +- tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp | 8 +- tests/validation/dynamic_fusion/gpu/cl/Sub.cpp | 8 +- tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp | 6 +- 13 files changed, 298 insertions(+), 84 deletions(-) (limited to 'tests/validation/dynamic_fusion') diff --git a/tests/validation/dynamic_fusion/gpu/Integration.cpp b/tests/validation/dynamic_fusion/gpu/Integration.cpp index 6a283f8082..3a915779c1 100644 --- a/tests/validation/dynamic_fusion/gpu/Integration.cpp +++ b/tests/validation/dynamic_fusion/gpu/Integration.cpp @@ -23,24 +23,33 @@ */ #include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/QuantizationInfo.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Types.h" #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h" #include "arm_compute/dynamic_fusion/sketch/attributes/CastAttributes.h" #include "arm_compute/dynamic_fusion/sketch/attributes/Conv2dAttributes.h" +#include "arm_compute/dynamic_fusion/sketch/attributes/DepthwiseConv2dAttributes.h" #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h" #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuAdd.h" #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuCast.h" #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h" +#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuDepthwiseConv2d.h" +#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h" #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h" +#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSigmoid.h" #include "tests/CL/CLAccessor.h" #include "tests/framework/Macros.h" #include "tests/validation/Validation.h" #include "tests/validation/dynamic_fusion/Utils.h" +#include "tests/validation/reference/ActivationLayer.h" #include "tests/validation/reference/ConvolutionLayer.h" #include "tests/validation/reference/DepthConvertLayer.h" +#include "tests/validation/reference/DepthwiseConvolutionLayer.h" #include "tests/validation/reference/ElementwiseOperations.h" #include "tests/validation/reference/Permute.h" +#include "tests/validation/reference/PixelWiseMultiplication.h" using namespace arm_compute::experimental::dynamic_fusion; using namespace arm_compute::test::validation::utils; @@ -69,17 +78,17 @@ TEST_CASE(Conv2d, framework::DatasetMode::ALL) // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Fuse conv2d Conv2dAttributes conv2d_attr{}; - TensorInfo input_info = sketch.create_tensor_info(t_input_shape, 1, data_type, data_layout); - TensorInfo weight_info = sketch.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout)); + TensorInfo input_info = context.create_tensor_info(t_input_shape, 1, data_type, data_layout); + TensorInfo weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout)); ITensorInfo *conv_out_info = GpuConv2d::create_op(sketch, &input_info, &weight_info, nullptr, conv2d_attr); - TensorInfo dst_info = sketch.create_tensor_info(); + TensorInfo dst_info = context.create_tensor_info(); GpuOutput::create_op(sketch, conv_out_info, &dst_info); // Configure runtime @@ -156,15 +165,15 @@ TEST_CASE(Add_Output_Add_Output, framework::DatasetMode::ALL) // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - TensorInfo in_0_info = sketch.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo in_1_info = sketch.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo in_2_info = sketch.create_tensor_info(t_input_shape, 1, data_type); + TensorInfo in_0_info = context.create_tensor_info(t_input_shape, 1, data_type); + TensorInfo in_1_info = context.create_tensor_info(t_input_shape, 1, data_type); + TensorInfo in_2_info = context.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo out_0_info = sketch.create_tensor_info(); - TensorInfo out_1_info = sketch.create_tensor_info(); + TensorInfo out_0_info = context.create_tensor_info(); + TensorInfo out_1_info = context.create_tensor_info(); ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, &in_0_info, &in_1_info); GpuOutput::create_op(sketch, ans_0_info, &out_0_info); @@ -253,15 +262,15 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL) // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - TensorInfo in_0_info = sketch.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo in_1_info = sketch.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo in_2_info = sketch.create_tensor_info(t_input_shape, 1, data_type); + TensorInfo in_0_info = context.create_tensor_info(t_input_shape, 1, data_type); + TensorInfo in_1_info = context.create_tensor_info(t_input_shape, 1, data_type); + TensorInfo in_2_info = context.create_tensor_info(t_input_shape, 1, data_type); - TensorInfo out_0_info = sketch.create_tensor_info(); - TensorInfo out_1_info = sketch.create_tensor_info(); + TensorInfo out_0_info = context.create_tensor_info(); + TensorInfo out_1_info = context.create_tensor_info(); CastAttributes cast_0_attr; cast_0_attr.data_type(DataType::S32).convert_policy(ConvertPolicy::SATURATE); @@ -348,6 +357,211 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL) validate(CLAccessor(t_out_0), ref_t_out_0, tolerance_add_f32); validate(CLAccessor(t_out_1), ref_t_out_1, tolerance_cast_f32); } + +TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL) +{ + // (tensor0) + // | + // ======|============================================== Sketch 0 + // | (tensor1) +---- (tensor2) + // | | | | + // +-- input -- weights -- biases --+ | + // | | | + // | Conv2d | | + // | | | + // +----------- output -------------+ | + // | | + // +-- input ---+ | + // | | | + // | Sigmoid | | + // | | | + // +-- output --+ | + // | | + // +-- input ---+ | + // | | | + // | Output | | + // | | | + // +-- output --+ | + // | | + // (tensor5) | + // | | + // +--------+ | + // ======|=============================|================ Sketch 1 + // | (tensor3) (tensor4) | + // | | | | + // +-- input -- weights -- biases --+ | + // | | | + // | DepthwiseConv2d | | + // | | | + // +----------- output -------------+ | + // | | + // +--+ +----------------+ + // | | + // +-- lhs -- rhs --+ + // | | + // | Multiply | + // | | + // +---- output ----+ + // | + // +-- input ---+ + // | | + // | Output | + // | | + // +-- output --+ + // | + // (tensor6) + + TensorShape conv2d_src_shape(10, 20, 30); + TensorShape conv2d_wei_shape(10, 3, 3, 5); + TensorShape conv2d_bia_shape(5); + TensorShape conv2d_dst_shape(5, 18, 28); + TensorShape dwc_wei_shape(5, 3, 3); + TensorShape dwc_bia_shape(5); + TensorShape dwc_dst_shape(5, 16, 26); + + // Initialize the context. + CLScheduler::get().default_reinit(); + + auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); + GpuWorkloadContext context(&cl_compile_ctx); + + auto tensor0_info = context.create_tensor_info(conv2d_src_shape, 1, DataType::F32, DataLayout::NHWC); + + // Create the first sketch: conv2d + cast + output. + GpuWorkloadSketch sketch0(&context); + + Conv2dAttributes conv2d_attr; + auto tensor1_info = context.create_tensor_info(conv2d_wei_shape, 1, DataType::F32, DataLayout::NHWC); + auto tensor2_info = context.create_tensor_info(conv2d_bia_shape, 1, DataType::F32, DataLayout::NHWC); + ARM_COMPUTE_EXPECT(GpuConv2d::validate_op(sketch0, &tensor0_info, &tensor1_info, &tensor2_info, conv2d_attr), framework::LogLevel::ERRORS); + auto ans_info = GpuConv2d::create_op(sketch0, &tensor0_info, &tensor1_info, &tensor2_info, conv2d_attr); + + ARM_COMPUTE_EXPECT(GpuSigmoid::validate_op(sketch0, ans_info), framework::LogLevel::ERRORS); + ans_info = GpuSigmoid::create_op(sketch0, ans_info); + + DepthwiseConv2dAttributes dwc_attr; + auto tensor3_info = context.create_tensor_info(dwc_wei_shape, 1, DataType::F32, DataLayout::NHWC); + auto tensor4_info = context.create_tensor_info(dwc_bia_shape, 1, DataType::F32, DataLayout::NHWC); + ARM_COMPUTE_EXPECT(!GpuDepthwiseConv2d::validate_op(sketch0, ans_info, &tensor3_info, &tensor4_info, dwc_attr), framework::LogLevel::ERRORS); + + auto tensor5_info = context.create_tensor_info(); + ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch0, ans_info, &tensor5_info), framework::LogLevel::ERRORS); + GpuOutput::create_op(sketch0, ans_info, &tensor5_info); + + // Create the first workload runtime. + ClWorkloadRuntime runtime0; + runtime0.configure(sketch0); + + // Create the second sketch: dwc + sigmoid + output. + GpuWorkloadSketch sketch1(&context); + + ARM_COMPUTE_EXPECT(GpuDepthwiseConv2d::validate_op(sketch1, &tensor5_info, &tensor3_info, &tensor4_info, dwc_attr), framework::LogLevel::ERRORS); + ans_info = GpuDepthwiseConv2d::create_op(sketch1, &tensor5_info, &tensor3_info, &tensor4_info, dwc_attr); + + ARM_COMPUTE_EXPECT(GpuMul::validate_op(sketch1, ans_info, &tensor2_info), framework::LogLevel::ERRORS); + ans_info = GpuMul::create_op(sketch1, ans_info, &tensor2_info); + + auto tensor6_info = context.create_tensor_info(); + ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch1, ans_info, &tensor6_info), framework::LogLevel::ERRORS); + GpuOutput::create_op(sketch1, ans_info, &tensor6_info); + + // Create the second workload runtime. + ClWorkloadRuntime runtime1; + runtime1.configure(sketch1); + + // Create the user tensors. + CLTensor tensor0; + CLTensor tensor1; + CLTensor tensor2; + CLTensor tensor3; + CLTensor tensor4; + CLTensor tensor5; + CLTensor tensor6; + + tensor0.allocator()->init(tensor0_info); + tensor1.allocator()->init(tensor1_info); + tensor2.allocator()->init(tensor2_info); + tensor3.allocator()->init(tensor3_info); + tensor4.allocator()->init(tensor4_info); + tensor5.allocator()->init(tensor5_info); + tensor6.allocator()->init(tensor6_info); + + tensor0.allocator()->allocate(); + tensor1.allocator()->allocate(); + tensor2.allocator()->allocate(); + tensor3.allocator()->allocate(); + tensor4.allocator()->allocate(); + tensor5.allocator()->allocate(); + tensor6.allocator()->allocate(); + + // Allocate the auxiliary tensors. + for(auto &data : runtime0.get_auxiliary_tensors()) + { + auto tensor = std::get<0>(data); + auto &tensor_info = std::get<1>(data); + auto mem_req = std::get<2>(data); + + tensor->allocator()->init(tensor_info, mem_req.alignment); + tensor->allocator()->allocate(); + } + + for(auto &data : runtime1.get_auxiliary_tensors()) + { + auto tensor = std::get<0>(data); + auto &tensor_info = std::get<1>(data); + auto mem_req = std::get<2>(data); + + tensor->allocator()->init(tensor_info, mem_req.alignment); + tensor->allocator()->allocate(); + } + + // Fill the input tensors with random data. + fill(CLAccessor(tensor0), 0, library.get()); + fill(CLAccessor(tensor1), 1, library.get()); + fill(CLAccessor(tensor2), 2, library.get()); + fill(CLAccessor(tensor3), 3, library.get()); + fill(CLAccessor(tensor4), 4, library.get()); + + // Run each runtime. + runtime0.run({ &tensor0, &tensor1, &tensor2, &tensor5 }); + runtime1.run({ &tensor5, &tensor3, &tensor4, &tensor2, &tensor6 }); + + // Compute the reference result. + SimpleTensor ref_conv2d_src(conv2d_src_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC); + SimpleTensor ref_conv2d_wei(conv2d_wei_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC); + SimpleTensor ref_conv2d_bia(conv2d_bia_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC); + SimpleTensor ref_dwc_wei(dwc_wei_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC); + SimpleTensor ref_dwc_bia(dwc_bia_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC); + + fill(ref_conv2d_src, 0, library.get()); + fill(ref_conv2d_wei, 1, library.get()); + fill(ref_conv2d_bia, 2, library.get()); + fill(ref_dwc_wei, 3, library.get()); + fill(ref_dwc_bia, 4, library.get()); + + PermutationVector nhwc_to_nchw(1, 2, 0); + + auto conv2d_dst_shape_nchw = conv2d_dst_shape; + permute(conv2d_dst_shape_nchw, nhwc_to_nchw); + const auto ref_conv2d_src_nchw = reference::permute(ref_conv2d_src, nhwc_to_nchw); + const auto ref_conv2d_wei_nchw = reference::permute(ref_conv2d_wei, nhwc_to_nchw); + const auto ref_conv2d_bia_nchw = reference::permute(ref_conv2d_bia, nhwc_to_nchw); + const auto ref_conv2d_dst_nchw = reference::convolution_layer(ref_conv2d_src_nchw, ref_conv2d_wei_nchw, ref_conv2d_bia_nchw, conv2d_dst_shape_nchw, PadStrideInfo()); + + const auto ref_sigmoid_dst_nchw = reference::activation_layer(ref_conv2d_dst_nchw, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); + + auto dwc_dst_shape_nchw = dwc_dst_shape; + permute(dwc_dst_shape_nchw, nhwc_to_nchw); + const auto ref_dwc_wei_nchw = reference::permute(ref_dwc_wei, nhwc_to_nchw); + const auto ref_dwc_bia_nchw = reference::permute(ref_dwc_bia, nhwc_to_nchw); + const auto ref_dwc_dst_nchw = reference::depthwise_convolution(ref_sigmoid_dst_nchw, ref_dwc_wei_nchw, ref_dwc_bia_nchw, dwc_dst_shape_nchw, PadStrideInfo(), 1); + + const auto ref_mul_dst_nchw = reference::pixel_wise_multiplication(ref_dwc_dst_nchw, ref_conv2d_bia_nchw, 1.0, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP, DataType::F32); + + constexpr RelativeTolerance tolerance(0.001f); + validate(CLAccessor(tensor6), ref_mul_dst_nchw, tolerance); +} + TEST_SUITE(Invalid_Fusion_Should_Fail) TEST_CASE(Multiple_Complex_Ops_0, framework::DatasetMode::ALL) { @@ -368,12 +582,12 @@ TEST_CASE(Multiple_Complex_Ops_0, framework::DatasetMode::ALL) // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Create tensor infos - TensorInfo input_info = sketch.create_tensor_info(t_input_shape, 1, data_type, data_layout); - TensorInfo weight_info = sketch.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout)); + TensorInfo input_info = context.create_tensor_info(t_input_shape, 1, data_type, data_layout); + TensorInfo weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout)); ITensorInfo *dst_info; // Fuse conv2d into the workload @@ -386,7 +600,7 @@ TEST_CASE(Multiple_Complex_Ops_0, framework::DatasetMode::ALL) } // Create tensor infos - TensorInfo weight_info_2 = sketch.create_tensor_info(t_weight_info); + TensorInfo weight_info_2 = context.create_tensor_info(t_weight_info); // Fuse conv2d into the workload { diff --git a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp index 0034b0f07f..d9a3d9533c 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp @@ -87,12 +87,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( { // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Validate Elementwise Add - auto lhs_info = sketch.create_tensor_info(input1_info); - auto rhs_info = sketch.create_tensor_info(input2_info); + auto lhs_info = context.create_tensor_info(input1_info); + auto rhs_info = context.create_tensor_info(input2_info); bool res = bool(GpuAdd::validate_op(sketch, &lhs_info, &rhs_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp index 177c02c2c7..dc46dd594e 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp @@ -69,11 +69,11 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( { // Create a new workload sketch CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Fuse Clamp - const TensorInfo src_info = sketch.create_tensor_info(input_info); + const TensorInfo src_info = context.create_tensor_info(input_info); ClampAttributes attributes {}; attributes.min_val(min_val) diff --git a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp index b6331d70c8..7ab7c8a3fc 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp @@ -242,12 +242,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi input_info, weights_info, biases_info, padding, stride, depth_multiplier, dilation, expected) { CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info); - const TensorInfo sketch_weights_info = sketch.create_tensor_info(weights_info); - const TensorInfo sketch_biases_info = sketch.create_tensor_info(biases_info); + const TensorInfo sketch_input_info = context.create_tensor_info(input_info); + const TensorInfo sketch_weights_info = context.create_tensor_info(weights_info); + const TensorInfo sketch_biases_info = context.create_tensor_info(biases_info); DepthwiseConv2dAttributes attributes {}; attributes.pad(padding) diff --git a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp index cccad182ca..f27a1796c9 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp @@ -157,12 +157,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( input_info, weights_info, biases_info, conv2d_attrs, expected) { auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info); - const TensorInfo sketch_weights_info = sketch.create_tensor_info(weights_info); - const TensorInfo sketch_biases_info = sketch.create_tensor_info(biases_info); + const TensorInfo sketch_input_info = context.create_tensor_info(input_info); + const TensorInfo sketch_weights_info = context.create_tensor_info(weights_info); + const TensorInfo sketch_biases_info = context.create_tensor_info(biases_info); bool is_valid = bool(GpuConv2d::validate_op(sketch, &sketch_input_info, &sketch_weights_info, &sketch_biases_info, conv2d_attrs)); ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS); } diff --git a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp index a9e8f9c15f..2da2b9eabd 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp @@ -102,12 +102,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( { // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Validate Elementwise Mul - auto lhs_info = sketch.create_tensor_info(input1_info); - auto rhs_info = sketch.create_tensor_info(input2_info); + auto lhs_info = context.create_tensor_info(input1_info); + auto rhs_info = context.create_tensor_info(input2_info); bool res = bool(GpuMul::validate_op(sketch, &lhs_info, &rhs_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp index a7772aef4d..f4478db42b 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp @@ -101,15 +101,15 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( { // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Declare GpuPool2d settings const GpuPool2dSettings &settings = GpuPool2dSettings().mixed_precision(false); // Validate Pool2d Configuration - auto src_info = sketch.create_tensor_info(input_info); - auto dst_info = sketch.create_tensor_info(output_info); + auto src_info = context.create_tensor_info(input_info); + auto dst_info = context.create_tensor_info(output_info); bool res = bool(GpuPool2d::validate_op(sketch, &src_info, &dst_info, pool2d_attr, settings)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } diff --git a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp index 6d88be448e..bdaa1be531 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp @@ -53,13 +53,13 @@ input_info, output_shape, expected) { // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Create sketch tensors TensorShape input_shape = input_info.tensor_shape(); ARM_COMPUTE_UNUSED(input_shape); - TensorInfo src_info = sketch.create_tensor_info(input_info); + TensorInfo src_info = context.create_tensor_info(input_info); ReshapeAttributes attributes; attributes.shape(output_shape); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp index 696be54c92..5f99cd6d78 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp @@ -95,10 +95,10 @@ TEST_CASE(NullPtr, framework::DatasetMode::ALL) const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, default_data_layout }; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info); + const TensorInfo sketch_input_info = context.create_tensor_info(input_info); // nullptr is given as input Status status = GpuResize::validate_op(sketch, nullptr, ResizeAttributes()); @@ -135,10 +135,10 @@ TEST_CASE(SupportDataType, framework::DatasetMode::ALL) const TensorInfo input_info = TensorInfo{ default_input_shape, 1, kv.first, default_data_layout }; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info); + const TensorInfo sketch_input_info = context.create_tensor_info(input_info); ResizeAttributes attributes; attributes.output_width(default_output_shape[0]); // shape is not important unless it's empty @@ -157,10 +157,10 @@ TEST_CASE(MismatchingDataType, framework::DatasetMode::ALL) const TensorInfo output_info = TensorInfo{ default_output_shape, 1, non_default_data_type, default_data_layout }; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info); + const TensorInfo sketch_input_info = context.create_tensor_info(input_info); Status status = GpuResize::validate_op(sketch, &sketch_input_info, ResizeAttributes()); ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS); @@ -177,10 +177,10 @@ TEST_CASE(AlignedCornerNotSupported, framework::DatasetMode::ALL) const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, default_data_layout }; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info); + const TensorInfo sketch_input_info = context.create_tensor_info(input_info); ResizeAttributes attributes{}; attributes.interpolation_policy(interpolation_policy) @@ -198,10 +198,10 @@ TEST_CASE(UnsupportedInterpolationPolicy, framework::DatasetMode::ALL) constexpr auto interpolation_policy = InterpolationPolicy::AREA; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info); + const TensorInfo sketch_input_info = context.create_tensor_info(input_info); ResizeAttributes attributes{}; attributes.interpolation_policy(interpolation_policy); @@ -217,10 +217,10 @@ TEST_CASE(UnsupportedLayout, framework::DatasetMode::ALL) constexpr auto interpolation_policy = InterpolationPolicy::BILINEAR; CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; - const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info); + const TensorInfo sketch_input_info = context.create_tensor_info(input_info); ResizeAttributes attributes{}; attributes.interpolation_policy(interpolation_policy); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp index aace23eff4..5fd11807bc 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp @@ -61,11 +61,11 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip( { // Create a new workload sketch CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Fuse sigmoid - const TensorInfo src_info = sketch.create_tensor_info(input_info); + const TensorInfo src_info = context.create_tensor_info(input_info); const bool res = static_cast(GpuSigmoid::validate_op(sketch, &src_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp index d09454e05b..e8314d700d 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp @@ -104,13 +104,13 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( { // Create a new workload sketch CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; SoftmaxAttributes softmax_attr{}; softmax_attr.axis(axis).beta(beta).is_log_softmax(false); - TensorInfo src_info = sketch.create_tensor_info(input_info); - TensorInfo dst_info = sketch.create_tensor_info(output_info); + TensorInfo src_info = context.create_tensor_info(input_info); + TensorInfo dst_info = context.create_tensor_info(output_info); const bool res = static_cast(GpuSoftmax::validate_op(sketch, &src_info, &dst_info, softmax_attr)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); } diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp index 977e0110da..0bb05c2961 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp @@ -89,12 +89,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( { // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + auto context = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Validate Elementwise Sub - auto lhs_info = sketch.create_tensor_info(input1_info); - auto rhs_info = sketch.create_tensor_info(input2_info); + auto lhs_info = context.create_tensor_info(input1_info); + auto rhs_info = context.create_tensor_info(input2_info); bool res = bool(GpuSub::validate_op(sketch, &lhs_info, &rhs_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); diff --git a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp index 183cd079a3..00c92fbfc2 100644 --- a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp +++ b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp @@ -61,11 +61,11 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip( { // Create a new workload sketch CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - GpuWorkloadContext gpu_ctx{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &gpu_ctx }; + GpuWorkloadContext context{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &context }; // Fuse tanh - const TensorInfo src_info = sketch.create_tensor_info(input_info); + const TensorInfo src_info = context.create_tensor_info(input_info); const bool res = static_cast(GpuTanh::validate_op(sketch, &src_info)); ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS); -- cgit v1.2.1