diff options
Diffstat (limited to 'tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h')
-rw-r--r-- | tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h | 72 |
1 files changed, 24 insertions, 48 deletions
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h index d11237748f..f97b541ce3 100644 --- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h +++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h @@ -29,6 +29,7 @@ #include "arm_compute/core/Types.h" #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h" #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h" +#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h" #include "tests/CL/CLAccessor.h" #include "tests/framework/Fixture.h" @@ -102,32 +103,30 @@ protected: auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; GpuWorkloadSketch sketch{ &gpu_ctx }; - TensorInfo dst_info{}; - TensorInfo dst_info_fuse{}; // Fuse first element wise binary Op auto lhs_info = sketch.create_tensor_info(shape0, 1, _data_type); auto rhs_info = sketch.create_tensor_info(TensorInfo(shape1, 1, _data_type)); + + auto ans_info = sketch.create_tensor_info(); + auto dst_info = sketch.create_tensor_info(); + TensorInfo rhs_info_fuse; + TensorInfo ans2_info; + + FunctionType::create_op(sketch, &lhs_info, &rhs_info, &ans_info); - // Testing root case while in-place - if(!_is_inplace) + if(_fuse) { - dst_info = sketch.create_tensor_info(TensorInfo(1, _data_type)); + rhs_info_fuse = sketch.create_tensor_info(shape2, 1, _data_type); + ans2_info = sketch.create_tensor_info(); - FunctionType::create_op(sketch, &lhs_info, &rhs_info, &dst_info); + FunctionType::create_op(sketch, &ans_info, &rhs_info_fuse, &ans2_info); + GpuOutput::create_op(sketch, &ans2_info, &dst_info); } else { - FunctionType::create_op(sketch, &lhs_info, &rhs_info, &lhs_info); - } - - if(_fuse) - { - // Fuse first element wise binary Op - rhs_info_fuse = sketch.create_tensor_info(TensorInfo(shape2, 1, _data_type)); - dst_info_fuse = sketch.create_tensor_info(); - FunctionType::create_op(sketch, &dst_info, &rhs_info_fuse, &dst_info_fuse); + GpuOutput::create_op(sketch, &ans_info, &dst_info); } // Configure runtime @@ -148,33 +147,24 @@ protected: TensorType t_rhs{}; TensorType t_rhs_fuse{}; TensorType t_dst{}; - TensorType t_dst_fuse{}; // Initialize user tensors t_lhs.allocator()->init(lhs_info); t_rhs.allocator()->init(rhs_info); - if(!_is_inplace) + t_dst.allocator()->init(dst_info); + if(_fuse) { - t_dst.allocator()->init(dst_info); - if(_fuse) - { - t_rhs_fuse.allocator()->init(rhs_info_fuse); - t_dst_fuse.allocator()->init(dst_info_fuse); - } + t_rhs_fuse.allocator()->init(rhs_info_fuse); } // Allocate and fill user tensors // Instead of using ACL allocator, the user can choose to import memory into the tensors t_lhs.allocator()->allocate(); t_rhs.allocator()->allocate(); - if(!_is_inplace) + t_dst.allocator()->allocate(); + if(_fuse) { - t_dst.allocator()->allocate(); - if(_fuse) - { - t_rhs_fuse.allocator()->allocate(); - t_dst_fuse.allocator()->allocate(); - } + t_rhs_fuse.allocator()->allocate(); } fill(AccessorType(t_lhs), 0); @@ -183,31 +173,17 @@ protected: { fill(AccessorType(t_rhs_fuse), 2); } + // Run runtime - if(_is_inplace) + if(_fuse) { - runtime.run({ &t_lhs, &t_rhs, &t_lhs }); + runtime.run({ &t_lhs, &t_rhs, &t_rhs_fuse, &t_dst }); } else { - if(_fuse) - { - runtime.run({ &t_lhs, &t_rhs, &t_rhs_fuse, &t_dst_fuse }); - } - else - { - runtime.run({ &t_lhs, &t_rhs, &t_dst }); - } + runtime.run({ &t_lhs, &t_rhs, &t_dst }); } - if(_is_inplace) - { - return t_lhs; - } - else if(_fuse) - { - return t_dst_fuse; - } return t_dst; } |