diff options
Diffstat (limited to 'tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h')
-rw-r--r-- | tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h | 122 |
1 files changed, 76 insertions, 46 deletions
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h index c6ac4b91db..65a3363e24 100644 --- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h +++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,7 +28,6 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" - #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h" #include "arm_compute/dynamic_fusion/sketch/attributes/MatMulAttributes.h" #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h" @@ -39,10 +38,10 @@ #include "tests/framework/Fixture.h" #include "tests/framework/Macros.h" #include "tests/validation/Helpers.h" -#include "tests/validation/Validation.h" #include "tests/validation/reference/GEMM.h" #include "tests/validation/reference/Permute.h" #include "tests/validation/reference/ReshapeLayer.h" +#include "tests/validation/Validation.h" using namespace arm_compute::experimental::dynamic_fusion; @@ -57,11 +56,11 @@ namespace template <typename U> void fill(U &&tensor, int i) { - switch(tensor.data_type()) + switch (tensor.data_type()) { case DataType::F16: { - arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f }; + arm_compute::utils::uniform_real_distribution_16bit<half> distribution{-1.0f, 1.0f}; library->fill(tensor, distribution, i); break; } @@ -80,67 +79,83 @@ void fill(U &&tensor, int i) template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class DynamicFusionGpuMatMulValidationGenericFixture : public framework::Fixture { - public: - void setup(TensorShape lhs_shape, TensorShape rhs_shape, TensorShape output_shape, bool transpose_a, bool transpose_b, - int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type) + void setup(TensorShape lhs_shape, + TensorShape rhs_shape, + TensorShape output_shape, + bool transpose_a, + bool transpose_b, + int M0, + int N0, + int K0, + bool export_rhs_to_cl_image, + DataType data_type) { //For brevity, the input shapes are assumed to be not-transposed for both a and b matrices. - if(transpose_a) + if (transpose_a) { permute(lhs_shape, PermutationVector(1U, 0U)); } - if(transpose_b) + if (transpose_b) { permute(rhs_shape, PermutationVector(1U, 0U)); } // Skip configurations unsupported by the device. _device_supports_export_to_cl_image = image2d_from_buffer_supported(CLKernelLibrary::get().get_device()); - if(!_device_supports_export_to_cl_image && export_rhs_to_cl_image) + if (!_device_supports_export_to_cl_image && export_rhs_to_cl_image) { ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped"); framework::ARM_COMPUTE_PRINT_INFO(); return; // Note: Also need to skip the validate in corresponding FIXTURE_DATA_TEST_CASEs. } - _target = compute_target(lhs_shape, rhs_shape, transpose_a, transpose_b, M0, N0, K0, export_rhs_to_cl_image, data_type); + _target = compute_target(lhs_shape, rhs_shape, transpose_a, transpose_b, M0, N0, K0, export_rhs_to_cl_image, + data_type); _reference = compute_reference(lhs_shape, rhs_shape, output_shape, transpose_a, transpose_b, data_type); } protected: - TensorType compute_target(TensorShape &shape_a, TensorShape &shape_b, bool transpose_a, bool transpose_b, int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type) + TensorType compute_target(TensorShape &shape_a, + TensorShape &shape_b, + bool transpose_a, + bool transpose_b, + int M0, + int N0, + int K0, + bool export_rhs_to_cl_image, + DataType data_type) { ARM_COMPUTE_UNUSED(export_rhs_to_cl_image); CLScheduler::get().default_reinit(); // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); - auto context = GpuWorkloadContext{ &cl_compile_ctx }; - GpuWorkloadSketch sketch{ &context }; + auto context = GpuWorkloadContext{&cl_compile_ctx}; + GpuWorkloadSketch sketch{&context}; // Create sketch tensors - TensorInfo lhs_info = context.create_tensor_info(TensorInfo(shape_a, 1, data_type)); - TensorInfo rhs_info = context.create_tensor_info(TensorInfo(shape_b, 1, data_type)); - TensorInfo dst_info = context.create_tensor_info(); + ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(shape_a, 1, data_type)); + ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(shape_b, 1, data_type)); + ITensorInfo *dst_info = context.create_tensor_info(); - MatMulAttributes matmul_attr {}; + MatMulAttributes matmul_attr{}; matmul_attr.adj_lhs(transpose_a); matmul_attr.adj_rhs(transpose_b); - GpuMatMulSettings matmul_settings {}; + GpuMatMulSettings matmul_settings{}; matmul_settings.m0(M0); matmul_settings.n0(N0); matmul_settings.k0(K0); - ITensorInfo *ans_info = FunctionType::create_op(sketch, &lhs_info, &rhs_info, matmul_attr, matmul_settings); - GpuOutput::create_op(sketch, ans_info, &dst_info); + ITensorInfo *ans_info = FunctionType::create_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings); + GpuOutput::create_op(sketch, ans_info, dst_info); // Configure runtime ClWorkloadRuntime runtime; runtime.configure(sketch); - for(auto &data : runtime.get_auxiliary_tensors()) + for (auto &data : runtime.get_auxiliary_tensors()) { CLTensor *tensor = std::get<0>(data); TensorInfo info = std::get<1>(data); @@ -155,9 +170,9 @@ protected: TensorType t_dst{}; // Initialize user tensors - t_lhs.allocator()->init(lhs_info); - t_rhs.allocator()->init(rhs_info); - t_dst.allocator()->init(dst_info); + t_lhs.allocator()->init(*lhs_info); + t_rhs.allocator()->init(*rhs_info); + t_dst.allocator()->init(*dst_info); ARM_COMPUTE_ASSERT(t_lhs.info()->is_resizable()); ARM_COMPUTE_ASSERT(t_rhs.info()->is_resizable()); @@ -176,12 +191,17 @@ protected: fill(AccessorType(t_rhs), 1); // Run runtime - runtime.run({ &t_lhs, &t_rhs, &t_dst }); + runtime.run({&t_lhs, &t_rhs, &t_dst}); return t_dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool pretranspose_a, bool pretranspose_b, DataType data_type) + SimpleTensor<T> compute_reference(const TensorShape &shape_a, + const TensorShape &shape_b, + const TensorShape &output_shape, + bool pretranspose_a, + bool pretranspose_b, + DataType data_type) { // We collapse dimensions > 3 onto dimension 3, i.e. 5D+ tensors will look like 4D // This is necessary unless we choose to extend gemm reference for 5D+ tensors @@ -190,9 +210,9 @@ protected: TensorShape shape_b_collapsed = shape_b.collapsed_from(Window::DimZ); // Create reference - SimpleTensor<T> a{ shape_a_collapsed, data_type, 1 }; - SimpleTensor<T> b{ shape_b_collapsed, data_type, 1 }; - SimpleTensor<T> c{ output_shape_collapsed, data_type, 1 }; + SimpleTensor<T> a{shape_a_collapsed, data_type, 1}; + SimpleTensor<T> b{shape_b_collapsed, data_type, 1}; + SimpleTensor<T> c{output_shape_collapsed, data_type, 1}; // Fill reference fill(a, 0); @@ -213,27 +233,27 @@ protected: b_transposed_shape.set(1, b.shape().x()); // Define transposed tensors - SimpleTensor<T> a_transposed{ a_transposed_shape, data_type }; - SimpleTensor<T> b_transposed{ b_transposed_shape, data_type }; + SimpleTensor<T> a_transposed{a_transposed_shape, data_type}; + SimpleTensor<T> b_transposed{b_transposed_shape, data_type}; //pretranspose a if necessary - if(pretranspose_a) + if (pretranspose_a) { a_transposed = reference::permute<T>(a, PermutationVector(1U, 0U)); } // pretranspose b if necessary - if(pretranspose_b) + if (pretranspose_b) { b_transposed = reference::permute<T>(b, PermutationVector(1U, 0U)); } // Use transposed tensors if boolean enabled else use original tensors - SimpleTensor<T> result = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, 1.0f, 0.f); - + SimpleTensor<T> result = + reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, 1.0f, 0.f); // We reshape the gemm output back if the tensor is high dimensional - if(output_shape_collapsed != output_shape) + if (output_shape_collapsed != output_shape) { // std::cout << "called reshape: \n"; result = reference::reshape_layer(result, output_shape); @@ -244,20 +264,30 @@ protected: CLTensor _target{}; SimpleTensor<T> _reference{}; - bool _device_supports_export_to_cl_image{ false }; - bool _device_supports_mmul{ false }; + bool _device_supports_export_to_cl_image{false}; + bool _device_supports_mmul{false}; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class DynamicFusionGpuMatMulValidationFixture : public DynamicFusionGpuMatMulValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +class DynamicFusionGpuMatMulValidationFixture + : public DynamicFusionGpuMatMulValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { - public: - void setup(TensorShape lhs_shape, TensorShape rhs_shape, TensorShape output_shape, bool transpose_a, bool transpose_b, - int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type) +public: + void setup(TensorShape lhs_shape, + TensorShape rhs_shape, + TensorShape output_shape, + bool transpose_a, + bool transpose_b, + int M0, + int N0, + int K0, + bool export_rhs_to_cl_image, + DataType data_type) { ARM_COMPUTE_UNUSED(export_rhs_to_cl_image); - DynamicFusionGpuMatMulValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(lhs_shape, rhs_shape, output_shape, transpose_a, transpose_b, M0, - N0, K0, false /* export_rhs_to_cl_image bias */, data_type); + DynamicFusionGpuMatMulValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup( + lhs_shape, rhs_shape, output_shape, transpose_a, transpose_b, M0, N0, K0, + false /* export_rhs_to_cl_image bias */, data_type); } }; |