From bbeef721c285d467d003a739a1e68b2c86899750 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Mon, 20 Mar 2023 10:19:10 +0000 Subject: Add Texture Pipe Support for Matmul Lhs T/NT Rhs NT kernels Resolves: COMPMID-5945, COMPMID-5954 Change-Id: I7b27021d21f8e08c4896f6b1f595a75125064f9e Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9356 Reviewed-by: Gian Marco Iodice Reviewed-by: SiCong Li Reviewed-by: Viet-Hoa Do Benchmark: Arm Jenkins Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- src/core/CL/CLHelpers.cpp | 6 +- src/core/CL/cl_kernels/common/mat_mul.cl | 28 +-- src/gpu/cl/kernels/ClNativeMatMulKernel.cpp | 71 +++++++- src/gpu/cl/kernels/ClNativeMatMulKernel.h | 3 + tests/datasets/LargeMatMulDataset.h | 12 ++ tests/datasets/SmallMatMulDataset.h | 16 ++ tests/validation/CL/MatMulKernel.cpp | 221 ++++++++++++++++++++---- tests/validation/fixtures/MatMulKernelFixture.h | 15 +- 8 files changed, 305 insertions(+), 67 deletions(-) diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp index b31864211c..6b011f1f7c 100644 --- a/src/core/CL/CLHelpers.cpp +++ b/src/core/CL/CLHelpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2022 Arm Limited. + * Copyright (c) 2016-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -443,7 +443,7 @@ void set_wbsm(cl::Kernel &kernel, cl_int wbsm_hint) bool export_to_cl_image(const ITensorInfo *tensor) { - if(tensor->tensor_shape()[0] % 4) + if(tensor->tensor_shape()[0] % 4 != 0) { return false; } @@ -467,7 +467,7 @@ bool export_to_cl_image(const ITensorInfo *tensor) } const size_t image_w = tensor->tensor_shape()[0] / 4; - const size_t image_h = tensor->tensor_shape()[1] * tensor->tensor_shape()[2] * tensor->tensor_shape()[3]; + const size_t image_h = tensor->tensor_shape().total_size() / tensor->tensor_shape()[0]; const size_t max_image_w = CLKernelLibrary::get().get_device().getInfo(); const size_t max_image_h = CLKernelLibrary::get().get_device().getInfo(); diff --git a/src/core/CL/cl_kernels/common/mat_mul.cl b/src/core/CL/cl_kernels/common/mat_mul.cl index 956d37a9d8..90ebf80a6a 100644 --- a/src/core/CL/cl_kernels/common/mat_mul.cl +++ b/src/core/CL/cl_kernels/common/mat_mul.cl @@ -33,10 +33,11 @@ * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4). * @note The number of leftover outputs rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3) * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6) + * @note The tensor type ("BUFFER" or "IMAGE") of the rhs tensor must be passed at compile time using -DRHS_TENSOR_TYPE (e.g. -DRHS_TENSOR_TYPE=BUFFER) * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_NT_NT) * @note Only the following configurations of M0, N0 and K0 are currently supported: * - M0 > 0 - * - N0 = 1, 2, 3, 4, 8, 16 + * - N0 = 1, 2, 3, 4, 8, 16 (only 4, 8, 16 if RHS_TENSOR_TYPE=IMAGE) * - K0 = 1, 2, 3, 4, 8, 16 * @note Values > 8 for M0 are not expected to be efficient * @@ -47,6 +48,7 @@ * @param[in] lhs_h The height of the lhs tensor * @param[in] lhs_n Number of the matrices (buffers) in the batch * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix + * @param[in] rhs_img (Optional) Read only cl_image object for the rhs tensor. 
Included when RHS_TENSOR_TYPE=IMAGE * @param[in] rhs_ptr Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr * @param[in] rhs_stride_y Stride of the rhs matrix in Y (2nd) dimension (in bytes) * @param[in] rhs_stride_z Stride of the rhs tensor in Z (3rd) dimension (in bytes) @@ -64,7 +66,7 @@ */ __kernel void mat_mul_native_nt_nt( TENSOR3D_T(lhs, BUFFER), - TENSOR3D_T(rhs, BUFFER), + TENSOR3D_T(rhs, RHS_TENSOR_TYPE), TENSOR3D_T(dst, BUFFER)) { const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0); @@ -73,7 +75,6 @@ __kernel void mat_mul_native_nt_nt( // Compute LHS/RHS/DST matrix address lhs_offset_first_element_in_bytes += y * lhs_stride_y + z * lhs_stride_z; - rhs_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + z * rhs_stride_z; dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z; // Initialize the accumulators @@ -84,6 +85,7 @@ __kernel void mat_mul_native_nt_nt( acc[i].v = 0.f; }) + const int rhs_z = z * rhs_h; int k; for(k = 0; k <= K - K0; k += K0) { @@ -102,12 +104,11 @@ __kernel void mat_mul_native_nt_nt( // Load tile from the lhs/rhs tensors T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a); - T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b); + T_LOAD(DATA_TYPE, K0, N0, RHS_TENSOR_TYPE, rhs, x, k + rhs_z, 1, rhs_stride_y, b); T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, a, b, acc); lhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE); - rhs_offset_first_element_in_bytes += K0 * rhs_stride_y; } #ifdef K % K0 != 0 @@ -129,12 +130,11 @@ __kernel void mat_mul_native_nt_nt( // Load tile from the lhs/rhs tensors T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a); - T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b); + T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, x, k + rhs_z, 1, rhs_stride_y, b); T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, a, b, acc); lhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE); - rhs_offset_first_element_in_bytes += 1 * rhs_stride_y; } #endif // K % K0 != 0 @@ -314,10 +314,11 @@ __kernel void mat_mul_native_nt_t(TENSOR3D_T(lhs, BUFFER), * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4). * @note The number of leftover outputs rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3) * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6) + * @note The tensor type ("BUFFER" or "IMAGE") of the rhs tensor must be passed at compile time using -DRHS_TENSOR_TYPE (e.g. -DRHS_TENSOR_TYPE=BUFFER) * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_T_NT) * @note Only the following configurations of M0, N0 and K0 are currently supported: * - M0 = 1, 2, 3, 4, 8, 16 - * - N0 = 1, 2, 3, 4, 8, 16 + * - N0 = 1, 2, 3, 4, 8, 16 (only 4, 8, 16 if RHS_TENSOR_TYPE=IMAGE) * - K0 > 0 * * @note Values > 8 for M0, and K0 are not expected to be efficient * @@ -328,6 +329,7 @@ __kernel void mat_mul_native_nt_t(TENSOR3D_T(lhs, BUFFER), * @param[in] lhs_h The height of the lhs tensor * @param[in] lhs_n Number of the matrices (buffers) in the batch * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix + * @param[in] rhs_img (Optional) Read only cl_image object for the rhs tensor. 
Included when RHS_TENSOR_TYPE=IMAGE * @param[in] rhs_ptr Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr * @param[in] rhs_stride_y Stride of the rhs matrix in Y (2nd) dimension (in bytes) * @param[in] rhs_stride_z Stride of the rhs tensor in Z (3rd) dimension (in bytes) @@ -345,7 +347,7 @@ __kernel void mat_mul_native_nt_t(TENSOR3D_T(lhs, BUFFER), */ __kernel void mat_mul_native_t_nt( TENSOR3D_T(lhs, BUFFER), - TENSOR3D_T(rhs, BUFFER), + TENSOR3D_T(rhs, RHS_TENSOR_TYPE), TENSOR3D_T(dst, BUFFER)) { const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0); @@ -354,7 +356,6 @@ __kernel void mat_mul_native_t_nt( // Compute LHS/RHS/DST matrix address lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z; - rhs_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + z * rhs_stride_z; dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z; // Initialize the accumulators @@ -365,6 +366,7 @@ __kernel void mat_mul_native_t_nt( acc[i].v = 0.f; }) + const int rhs_z = z * rhs_h; int k; for(k = 0; k <= K - K0; k += K0) { @@ -383,7 +385,7 @@ __kernel void mat_mul_native_t_nt( // Load tile from the lhs/rhs tensors T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a); - T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b); + T_LOAD(DATA_TYPE, K0, N0, RHS_TENSOR_TYPE, rhs, x, k + rhs_z, 1, rhs_stride_y, b); #if GPU_ARCH == GPU_ARCH_MIDGARD // For explanation, see mat_mul_native_nt_t @@ -401,7 +403,6 @@ __kernel void mat_mul_native_t_nt( #endif // GPU_ARCH == GPU_ARCH_MIDGARD lhs_offset_first_element_in_bytes += K0 * lhs_stride_y; - rhs_offset_first_element_in_bytes += K0 * rhs_stride_y; } #ifdef K % K0 != 0 @@ -423,7 +424,7 @@ __kernel void mat_mul_native_t_nt( // Load tile from the lhs/rhs tensors T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a); - T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b); + T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, x, k + rhs_z, 1, rhs_stride_y, b); #if GPU_ARCH == GPU_ARCH_MIDGARD // For explanation, see mat_mul_native_nt_t @@ -438,7 +439,6 @@ __kernel void mat_mul_native_t_nt( #endif // GPU_ARCH == GPU_ARCH_MIDGARD lhs_offset_first_element_in_bytes += 1 * lhs_stride_y; - rhs_offset_first_element_in_bytes += 1 * rhs_stride_y; } #endif // K % K0 != 0 diff --git a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp index ffbaf49c02..c1f150d7aa 100644 --- a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp +++ b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp @@ -22,16 +22,21 @@ * SOFTWARE. 
*/ #include "src/gpu/cl/kernels/ClNativeMatMulKernel.h" + +#include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/ITensorPack.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "src/core/helpers/AutoConfiguration.h" -#include "arm_compute/core/ITensorPack.h" #include "src/common/utils/Log.h" +#include "src/core/CL/CLUtils.h" +#include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" +#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" + #include "support/Cast.h" -#include "utils/TypePrinter.h" +#include "support/StringSupport.h" namespace arm_compute { @@ -54,7 +59,7 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info) if(adj_lhs) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(((m0 & (m0 - 1)) && (m0 != 3)) || (m0 > 16), "Only 1,2,3,4,8,16 are supported for N0 for Lhs transposed"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(((m0 & (m0 - 1)) && (m0 != 3)) || (m0 > 16), "Only 1,2,3,4,8,16 are supported for M0 for Lhs transposed"); } // Validate N0 @@ -88,6 +93,27 @@ Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rh return Status{}; } + +Status validate_export_to_cl_image(const ITensorInfo *rhs, const MatMulKernelInfo &matmul_kernel_info) +{ + ARM_COMPUTE_RETURN_ERROR_ON(matmul_kernel_info.export_rhs_to_cl_image && rhs->lock_paddings()); + if(matmul_kernel_info.export_rhs_to_cl_image) + { + if(matmul_kernel_info.adj_rhs) + { + const int k0 = matmul_kernel_info.k0; + ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 != 4 && k0 != 8 && k0 != 16, "K0 can only be: 4, 8, and 16 for Rhs transposed"); + } + else + { + const int n0 = matmul_kernel_info.n0; + ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 != 4 && n0 != 8 && n0 != 16, "N0 can only be: 4, 8, and 16 for Rhs non-transposed"); + } + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!export_to_cl_image(rhs), "Export to CLImage is not supported for this device/configuration"); + } + + return Status {}; +} } ClNativeMatMulKernel::ClNativeMatMulKernel() { @@ -100,6 +126,7 @@ Status ClNativeMatMulKernel::validate(const ITensorInfo *lhs, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_export_to_cl_image(rhs, matmul_kernel_info)); if(output->total_size() != 0) { @@ -114,10 +141,10 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT { ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, output, &compile_context, &matmul_kernel_info); ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output, matmul_kernel_info); + ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, output, matmul_kernel_info)); // output tensor auto initialization if not yet initialized auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); - ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, output, matmul_kernel_info)); const int m = output->dimension(1); const int n = output->dimension(0); @@ -127,14 +154,16 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT int m0 = adj_lhs ? 
adjust_vec_size(matmul_kernel_info.m0, m) : std::min(matmul_kernel_info.m0, m); int n0 = adjust_vec_size(matmul_kernel_info.n0, n); + _export_rhs_to_cl_image = matmul_kernel_info.export_rhs_to_cl_image && !rhs->lock_paddings(); + // Configure kernel window Window win = calculate_max_window(*output, Steps(n0, m0)); win = win.collapse(win, Window::DimZ); IClKernel::configure_internal(win); // Calculate partial (store instead of load) M0 and partial N0 for the partial blocks at the end of a row/column if any. This is to avoid padding. - const unsigned int partial_store_m0 = m % m0; // M is output->dimension(1) - const unsigned int partial_store_n0 = n % n0; // N is output->dimension(0) + const unsigned int partial_store_m0 = m % m0; + const unsigned int partial_store_n0 = n % n0; CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(lhs->data_type())); @@ -144,6 +173,7 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT build_opts.add_option("-DPARTIAL_STORE_M0=" + support::cpp11::to_string(partial_store_m0)); build_opts.add_option("-DPARTIAL_STORE_N0=" + support::cpp11::to_string(partial_store_n0)); build_opts.add_option("-DK=" + support::cpp11::to_string(k)); + build_opts.add_option_if_else(_export_rhs_to_cl_image, "-DRHS_TENSOR_TYPE=IMAGE", "-DRHS_TENSOR_TYPE=BUFFER"); std::string kernel_name("mat_mul_native"); kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt"; @@ -152,6 +182,11 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT // A macro guard to compile ONLY the kernel of interest build_opts.add_option("-D" + upper_string(kernel_name)); + if(_export_rhs_to_cl_image) + { + gemm::update_padding_for_cl_image(rhs); + } + // Create kernel _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); @@ -160,12 +195,16 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT _config_id += "_"; _config_id += lower_string(string_from_data_type(lhs->data_type())); _config_id += "_"; - _config_id += support::cpp11::to_string(output->dimension(1)); + _config_id += support::cpp11::to_string(m); _config_id += "_"; - _config_id += support::cpp11::to_string(output->dimension(0)); + _config_id += support::cpp11::to_string(n); + _config_id += "_"; + _config_id += support::cpp11::to_string(k); _config_id += "_"; _config_id += support::cpp11::to_string(output->dimension(2)); _config_id += "_"; + _config_id += support::cpp11::to_string(_export_rhs_to_cl_image); + _config_id += "_"; _config_id += support::cpp11::to_string(m0); _config_id += "_"; _config_id += support::cpp11::to_string(n0); @@ -188,6 +227,20 @@ void ClNativeMatMulKernel::run_op(ITensorPack &tensors, const Window &window, cl Window window_collapsed = window.collapse(ICLKernel::window(), Window::DimZ); add_3d_tensor_nhw_argument(idx, lhs); + + cl::Image2D rhs_cl_image; + if(_export_rhs_to_cl_image) + { + const size_t image_w = rhs->info()->dimension(0) / 4; + const size_t image_h = rhs->info()->tensor_shape().total_size() / rhs->info()->dimension(0); + const TensorShape shape2d(image_w, image_h); + const size_t image_row_pitch = rhs->info()->strides_in_bytes()[1]; + + // Export cl_buffer to cl_image + rhs_cl_image = create_image2d_from_buffer(CLKernelLibrary::get().context(), rhs->cl_buffer(), shape2d, rhs->info()->data_type(), image_row_pitch, CLImage2DType::ReadOnly); + _kernel.setArg(idx++, rhs_cl_image); + } + add_3d_tensor_nhw_argument(idx, rhs); add_3d_tensor_nhw_argument(idx, output); 
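The buffer-to-image export in run_op() above derives a 2D cl_image size from an N-dimensional RHS tensor. A minimal standalone C++ sketch of that size computation follows (illustrative only, using an arbitrary example shape; it is not the library code, just the same arithmetic as run_op() and the updated export_to_cl_image() check):

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // Example (hypothetical) RHS shape [W, H, D, ...]; W must be a multiple of 4
    // because one image texel packs 4 elements along dimension 0.
    const std::vector<std::size_t> shape = { 20, 7, 2, 1, 2 };

    std::size_t total = 1;
    for(std::size_t d : shape)
    {
        total *= d;
    }

    assert(shape[0] % 4 == 0);

    // All dimensions above the first are collapsed into the image height, so the
    // computation also covers tensors with more than 4 dimensions (the previous
    // dim1 * dim2 * dim3 formula did not).
    const std::size_t image_w = shape[0] / 4;
    const std::size_t image_h = total / shape[0];

    std::cout << "cl_image size: " << image_w << " x " << image_h << std::endl; // 5 x 28
    return 0;
}
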
diff --git a/src/gpu/cl/kernels/ClNativeMatMulKernel.h b/src/gpu/cl/kernels/ClNativeMatMulKernel.h index 1cd74365df..021292a4ae 100644 --- a/src/gpu/cl/kernels/ClNativeMatMulKernel.h +++ b/src/gpu/cl/kernels/ClNativeMatMulKernel.h @@ -63,6 +63,9 @@ public: // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; + +private: + bool _export_rhs_to_cl_image { false }; }; } // namespace kernels } // namespace opencl diff --git a/tests/datasets/LargeMatMulDataset.h b/tests/datasets/LargeMatMulDataset.h index cbc97d5e4a..b5181bc30b 100644 --- a/tests/datasets/LargeMatMulDataset.h +++ b/tests/datasets/LargeMatMulDataset.h @@ -54,6 +54,18 @@ public: } }; +class LargeMatMulDatasetRhsExportToCLImageRhsNT final : public MatMulDataset +{ +public: + // For shape choices, please refer to the explanations given in SmallMatMulDatasetRhsExportToCLImageRhsNT + LargeMatMulDatasetRhsExportToCLImageRhsNT() + { + add_config(TensorShape(21U, 13U, 3U, 2U), TensorShape(32U, 21U, 3U, 2U), TensorShape(32U, 13U, 3U, 2U)); + add_config(TensorShape(38U, 12U, 1U, 5U, 2U), TensorShape(20U, 38U, 1U, 5U, 2U), TensorShape(20U, 12U, 1U, 5U, 2U)); + add_config(TensorShape(45U, 38U, 3U, 2U, 3U), TensorShape(20U, 45U, 3U, 2U, 3U), TensorShape(20U, 38U, 3U, 2U, 3U)); + } +}; + } // namespace datasets } // namespace test } // namespace arm_compute diff --git a/tests/datasets/SmallMatMulDataset.h b/tests/datasets/SmallMatMulDataset.h index ae92b9abf5..93e5f7dc2c 100644 --- a/tests/datasets/SmallMatMulDataset.h +++ b/tests/datasets/SmallMatMulDataset.h @@ -57,6 +57,22 @@ public: } }; +class SmallMatMulDatasetRhsExportToCLImageRhsNT final : public MatMulDataset +{ +public: + // Some considerations: + // (1) N (Dimension 0 of Rhs matrix) dimension should be a multiple of 4 + // (2) Having N=20 enables us to test all possible N0 values, i.e. 4, 8, 16 + // (3) It's important to have more than one loop iterations in the K dimension + // K has been chosen in accordance with K0 + // (4) The 5-th dimension has been chosen as non-unit because export_to_cl_iamge checks + // were using dim1 * dim2 * dim3 to calculate the CLImage height; however, in our case + // the tensor can be > 4D. 
To stress that case, the fifth dimension is chosen to be non-unit as well + SmallMatMulDatasetRhsExportToCLImageRhsNT() + { + add_config(TensorShape(7U, 3U, 2U, 1U, 2U), TensorShape(20U, 7U, 2U, 1U, 2U), TensorShape(20U, 3U, 2U, 1U, 2U)); + } +}; } // namespace datasets } // namespace test } // namespace arm_compute diff --git a/tests/validation/CL/MatMulKernel.cpp b/tests/validation/CL/MatMulKernel.cpp index 5d2e59ab4c..59af8dba45 100644 --- a/tests/validation/CL/MatMulKernel.cpp +++ b/tests/validation/CL/MatMulKernel.cpp @@ -95,6 +95,12 @@ TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL) { MatMulKernelInfo(false, false, 9, 1, 2), true }, { MatMulKernelInfo(false, false, 3, 16, 3), true }, { MatMulKernelInfo(false, false, 7, 3, 4), true }, + { MatMulKernelInfo(false, false, 7, 3, 4, true), false }, // N0 not in {4, 8, 16} + { MatMulKernelInfo(false, false, 7, 1, 4, true), false }, // N0 not in {4, 8, 16} + { MatMulKernelInfo(false, false, 7, 12, 4, true), false }, // N0 not in {4, 8, 16} + { MatMulKernelInfo(false, false, 7, 4, 4, true), true }, + { MatMulKernelInfo(false, false, 7, 8, 4, true), true }, + { MatMulKernelInfo(false, false, 7, 16, 4, true), true }, // Lhs not-transposed, Rhs transposed { MatMulKernelInfo(false, true, 0, 1, 1), false }, // M0 should be > 0 @@ -115,6 +121,12 @@ TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL) { MatMulKernelInfo(true, false, 4, 1, 22), true }, { MatMulKernelInfo(true, false, 3, 3, 3), true }, { MatMulKernelInfo(true, false, 2, 4, 8), true }, + { MatMulKernelInfo(true, false, 2, 3, 8, true), false }, // N0 not in {4, 8, 16} + { MatMulKernelInfo(true, false, 2, 7, 8, true), false }, // N0 not in {4, 8, 16} + { MatMulKernelInfo(true, false, 2, 5, 8, true), false }, // N0 not in {4, 8, 16} + { MatMulKernelInfo(true, false, 2, 4, 8, true), true }, + { MatMulKernelInfo(true, false, 2, 8, 8, true), true }, + { MatMulKernelInfo(true, false, 2, 16, 8, true), true }, // // Lhs transposed, Rhs-transposed { MatMulKernelInfo(true, true, 2, 1, 5), false }, // K0 should in {1, 2, 3, 4, 8, 16} @@ -134,12 +146,65 @@ TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL) const TensorInfo lhs_info = TensorInfo(TensorShape(100U, 100U), 1, DataType::F32); const TensorInfo rhs_info = TensorInfo(TensorShape(100U, 100U), 1, DataType::F32); + const bool export_to_cl_image_supported = image2d_from_buffer_supported(CLKernelLibrary::get().get_device()); for(auto &pair : supported_block_sizes) { TensorInfo output_info; Status status = ClNativeMatMulKernel::validate(&lhs_info, &rhs_info, &output_info, pair.first); - ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS); + if(!pair.first.export_rhs_to_cl_image || export_to_cl_image_supported) + { + ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS); + } + } +} + +TEST_CASE(ExportToCLImage, framework::DatasetMode::ALL) +{ + // We skip this test if the hardware does not support exporting to CL Image + if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device())) + { + constexpr size_t pixel_size = 4; + const size_t max_image_w = pixel_size * CLKernelLibrary::get().get_device().getInfo(); + const size_t max_image_h = CLKernelLibrary::get().get_device().getInfo(); + + using ShapeConfigurationTuple = std::tuple; + const std::vector shape_configurations = + { + // lhs_shape, rhs_shape, adj_lhs, adj_rhs, expected + // Lhs t/Nt, Rhs Nt + // Transposition of Lhs doesn't add any value to the tests, therefore always assumed false below + { 
TensorShape(5U, 1U), TensorShape(3U, 5U), false, false, false }, // N should be multiple of 4 + { TensorShape(5U, 1U), TensorShape(14U, 5U), false, false, false }, // N should be multiple of 4 + { TensorShape(5U, 1U), TensorShape(12U, 5U), false, false, true }, + { TensorShape(5U, 1U), TensorShape(8U, 5U), false, false, true }, + { TensorShape(5U, 1U), TensorShape(4U, 5U), false, false, true }, + { TensorShape(max_image_h + 1, 1U), TensorShape(4U, max_image_h + 1), false, false, false }, // Cannot fit into CL Image memory's height + { TensorShape(5U, 1U), TensorShape(max_image_w + 1, 5U), false, false, false }, // Cannot fit into CL Image memory's width + { TensorShape(max_image_h, 1U), TensorShape(4U, max_image_h), false, false, true }, // Barely fits into CL Image memory's height + { TensorShape(5U, 1U), TensorShape(max_image_w, 5U), false, false, true }, // Barely fits into CL Image memory's width + }; + + for(auto &tuple : shape_configurations) + { + TensorShape lhs_shape = std::get<0>(tuple); + TensorShape rhs_shape = std::get<1>(tuple); + + const TensorInfo lhs_info = TensorInfo(lhs_shape, 1, DataType::F32); + const TensorInfo rhs_info = TensorInfo(rhs_shape, 1, DataType::F32); + + const bool adj_lhs = std::get<2>(tuple); + const bool adj_rhs = std::get<3>(tuple); + + // We choose M0, N0, K0 equal to 4 so that they're always valid for CLImage in any combination + const MatMulKernelInfo matmul_kernel_info {adj_lhs, adj_rhs, 4, 4, 4, true /* export_rhs_to_cl_image */}; + + TensorInfo output_info; + Status status = ClNativeMatMulKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info); + + const bool expected = std::get<4>(tuple); + ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS); + } } } @@ -244,68 +309,75 @@ TEST_SUITE_END() // Validate TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunTiny, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(datasets::TinyMatMulDataset(), - framework::dataset::make("pretransose_A", { false, true })), - framework::dataset::make("pretransose_B", { false, true })), - m0_values_precommit), - n0_values_precommit), - k0_values_precommit), - framework::dataset::make("DataType", DataType::F32))) +TEST_SUITE(Buffer) +FIXTURE_DATA_TEST_CASE(RunTiny, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::TinyMatMulDataset(), + framework::dataset::make("pretransose_A", { false, true })), + framework::dataset::make("pretransose_B", { false, true })), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("export_rhs_to_cl_image", { false })), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(), - framework::dataset::make("pretransose_A", { false, true })), - framework::dataset::make("pretransose_B", { false, true })), - m0_values_precommit), - n0_values_precommit), - k0_values_precommit), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(), + framework::dataset::make("pretransose_A", { false, true })), + 
framework::dataset::make("pretransose_B", { false, true })), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("export_rhs_to_cl_image", { false })), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), +FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), framework::dataset::make("pretransose_A", { false })), framework::dataset::make("pretransose_B", { false })), m0_values_nightly_lhs_nt), n0_values_nightly_rhs_nt), k0_values_nightly_lhs_nt_rhs_nt), + framework::dataset::make("export_rhs_to_cl_image", { false })), framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLargeRhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), +FIXTURE_DATA_TEST_CASE(RunLargeRhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), framework::dataset::make("pretransose_A", { false })), framework::dataset::make("pretransose_B", { true })), m0_values_nightly_lhs_nt), n0_values_nightly_rhs_t), k0_values_nightly_rhs_t), + framework::dataset::make("export_rhs_to_cl_image", { false })), framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), +FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), framework::dataset::make("pretransose_A", { true })), framework::dataset::make("pretransose_B", { false })), m0_values_nightly_lhs_t), n0_values_nightly_rhs_nt), k0_values_nightly_lhs_t_rhs_nt), + framework::dataset::make("export_rhs_to_cl_image", { false })), framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposedRhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), + combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), framework::dataset::make("pretransose_A", { true })), framework::dataset::make("pretransose_B", { true })), m0_values_nightly_lhs_t), n0_values_nightly_rhs_t), k0_values_nightly_rhs_t), + framework::dataset::make("export_rhs_to_cl_image", { false })), framework::dataset::make("DataType", DataType::F32))) { // Validate output @@ -313,75 +385,150 @@ FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposedRhsTransposed, CLMatMulKernelFixture } // Running High Dimensional test is enough for FP32, because we're stressing the number of 
dimensions, not data type or M0/N0/K0 // It's a good idea to test for each Lhs/Rhs T/NT combinations because they're different CL kernels -FIXTURE_DATA_TEST_CASE(RunHighDimensional, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(datasets::HighDimensionalMatMulDataset(), - framework::dataset::make("pretransose_A", { false, true })), - framework::dataset::make("pretransose_B", { false, true })), - framework::dataset::make("M0", { 2 })), - framework::dataset::make("N0", { 2 })), - framework::dataset::make("K0", { 2 })), - framework::dataset::make("DataType", DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunHighDimensional, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::HighDimensionalMatMulDataset(), + framework::dataset::make("pretransose_A", { false, true })), + framework::dataset::make("pretransose_B", { false, true })), + framework::dataset::make("M0", { 2 })), + framework::dataset::make("N0", { 2 })), + framework::dataset::make("K0", { 2 })), + framework::dataset::make("export_rhs_to_cl_image", { false })), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); } +TEST_SUITE_END() // Buffer + +TEST_SUITE(ExportRhsToCLImage) +FIXTURE_DATA_TEST_CASE(RunSmallRhsNotTransposed, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDatasetRhsExportToCLImageRhsNT(), + framework::dataset::make("pretransose_A", { true, false })), + framework::dataset::make("pretransose_B", { false })), + framework::dataset::make("M0", { 2 })), + framework::dataset::make("N0", { 4, 8, 16 })), + framework::dataset::make("K0", { 2, 4 })), + framework::dataset::make("export_rhs_to_cl_image", { true })), + framework::dataset::make("DataType", DataType::F32))) +{ + // Validate output + if(_device_supports_export_to_cl_image) + { + validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); + } +} +FIXTURE_DATA_TEST_CASE(RunLargeRhsNotTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDatasetRhsExportToCLImageRhsNT(), + framework::dataset::make("pretransose_A", { true, false })), + framework::dataset::make("pretransose_B", { false })), + framework::dataset::make("M0", { 2 })), // Choices of M0 does not matter much because it's related to Lhs tensor + framework::dataset::make("N0", { 4, 8, 16 })), + framework::dataset::make("K0", { 1, 2, 3, 4 })), + framework::dataset::make("export_rhs_to_cl_image", { true })), + framework::dataset::make("DataType", DataType::F32))) +{ + // Validate output + if(_device_supports_export_to_cl_image) + { + validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32); + } +} +TEST_SUITE_END() // ExportRhsToCLImage TEST_SUITE_END() // FP32 TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(), - framework::dataset::make("pretransose_A", { false, true })), - framework::dataset::make("pretransose_B", { false, true })), - m0_values_precommit), - n0_values_precommit), - k0_values_precommit), - framework::dataset::make("DataType", DataType::F16))) +TEST_SUITE(Buffer) +FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulKernelFixture, 
framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(), + framework::dataset::make("pretransose_A", { false, true })), + framework::dataset::make("pretransose_B", { false, true })), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("export_rhs_to_cl_image", { false })), + framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), +FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), framework::dataset::make("pretransose_A", { false })), framework::dataset::make("pretransose_B", { false })), m0_values_nightly_lhs_nt), n0_values_nightly_rhs_nt), k0_values_nightly_lhs_nt_rhs_nt), + framework::dataset::make("export_rhs_to_cl_image", { false })), framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLargeRhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), +FIXTURE_DATA_TEST_CASE(RunLargeRhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), framework::dataset::make("pretransose_A", { false })), framework::dataset::make("pretransose_B", { true })), m0_values_nightly_lhs_nt), n0_values_nightly_rhs_t), k0_values_nightly_rhs_t), + framework::dataset::make("export_rhs_to_cl_image", { false })), framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), +FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), framework::dataset::make("pretransose_A", { true })), framework::dataset::make("pretransose_B", { false })), m0_values_nightly_lhs_t), n0_values_nightly_rhs_nt), k0_values_nightly_lhs_t_rhs_nt), + framework::dataset::make("export_rhs_to_cl_image", { false })), framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposedRhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), +FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposedRhsTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(), framework::dataset::make("pretransose_A", { true })), framework::dataset::make("pretransose_B", { true })), m0_values_nightly_lhs_t), n0_values_nightly_rhs_t), k0_values_nightly_rhs_t), + framework::dataset::make("export_rhs_to_cl_image", 
{ false })), framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); } +TEST_SUITE_END() // Buffer + +TEST_SUITE(ExportRhsToCLImage) +FIXTURE_DATA_TEST_CASE(RunSmallRhsCLImageRhsNotTransposed, CLMatMulKernelFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDatasetRhsExportToCLImageRhsNT(), + framework::dataset::make("pretransose_A", { true, false })), + framework::dataset::make("pretransose_B", { false })), + framework::dataset::make("M0", { 2 })), + framework::dataset::make("N0", { 4, 8, 16 })), + framework::dataset::make("K0", { 2, 4 })), + framework::dataset::make("export_rhs_to_cl_image", { true })), + framework::dataset::make("DataType", DataType::F16))) +{ + // Validate output + if(_device_supports_export_to_cl_image) + { + validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); + } +} +FIXTURE_DATA_TEST_CASE(RunLargeRhsCLImageRhsNotTransposed, CLMatMulKernelFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDatasetRhsExportToCLImageRhsNT(), + framework::dataset::make("pretransose_A", { true, false })), + framework::dataset::make("pretransose_B", { false })), + framework::dataset::make("M0", { 2 })), // Choices of M0 does not matter much because it's related to Lhs tensor + framework::dataset::make("N0", { 4, 8, 16 })), + framework::dataset::make("K0", { 1, 2, 3, 4 })), + framework::dataset::make("export_rhs_to_cl_image", { true })), + framework::dataset::make("DataType", DataType::F16))) +{ + // Validate output + if(_device_supports_export_to_cl_image) + { + validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16); + } +} +TEST_SUITE_END() // ExportRhsToCLImage TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float TEST_SUITE_END() // MatMulKernel diff --git a/tests/validation/fixtures/MatMulKernelFixture.h b/tests/validation/fixtures/MatMulKernelFixture.h index 459564618f..c131fea7fa 100644 --- a/tests/validation/fixtures/MatMulKernelFixture.h +++ b/tests/validation/fixtures/MatMulKernelFixture.h @@ -48,7 +48,7 @@ class MatMulKernelValidationFixture : public framework::Fixture { public: template - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, const int M0, const int N0, const int K0, DataType data_type) + void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, const int M0, const int N0, const int K0, bool export_rhs_to_cl_image, DataType data_type) { // For brevity, the input shapes are assumed to be not-transposed for both Lhs and Rhs matrices. 
if(pretranspose_a) @@ -61,8 +61,13 @@ public: permute(shape_b, PermutationVector(1U, 0U)); } - _target = compute_target(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, M0, N0, K0, data_type); - _reference = compute_reference(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, data_type); + _device_supports_export_to_cl_image = image2d_from_buffer_supported(CLKernelLibrary::get().get_device()); + + if(!export_rhs_to_cl_image || _device_supports_export_to_cl_image) + { + _target = compute_target(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, M0, N0, K0, export_rhs_to_cl_image, data_type); + _reference = compute_reference(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, data_type); + } } protected: @@ -89,7 +94,7 @@ protected: } CLTensor compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool pretranspose_a, bool pretranspose_b, const int M0, const int N0, const int K0, - DataType data_type) + bool export_rhs_to_cl_image, DataType data_type) { // Create tensors CLTensor a = create_tensor(shape_a, data_type, 1); @@ -103,6 +108,7 @@ protected: matmul_info.m0 = M0; matmul_info.n0 = N0; matmul_info.k0 = K0; + matmul_info.export_rhs_to_cl_image = export_rhs_to_cl_image; matMul.configure(a.info(), b.info(), dst.info(), matmul_info); ARM_COMPUTE_ASSERT(a.info()->is_resizable()); @@ -195,6 +201,7 @@ protected: CLTensor _target{}; SimpleTensor _reference{}; + bool _device_supports_export_to_cl_image { true }; }; } // namespace validation -- cgit v1.2.1
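
As a usage note, the block-size rule that export_rhs_to_cl_image introduces in validate_export_to_cl_image() can be restated with the small standalone check below; this is an illustrative sketch under the patch's stated restrictions, not the library's API:

#include <iostream>

// Restatement of the rule added by this patch: when the RHS tensor is exported to a
// cl_image, the vector length read along an image row (N0 for a non-transposed RHS,
// K0 for a transposed RHS) must be 4, 8 or 16.
static bool rhs_image_block_size_ok(bool adj_rhs, int n0, int k0)
{
    const int row_vector_length = adj_rhs ? k0 : n0;
    return row_vector_length == 4 || row_vector_length == 8 || row_vector_length == 16;
}

int main()
{
    std::cout << rhs_image_block_size_ok(false, 8, 3) << std::endl;  // 1: N0 = 8 is accepted
    std::cout << rhs_image_block_size_ok(false, 3, 4) << std::endl;  // 0: N0 = 3 is rejected
    std::cout << rhs_image_block_size_ok(true, 16, 2) << std::endl;  // 0: K0 = 2 is rejected for a transposed RHS
    return 0;
}

Values of 4, 8 and 16 correspond to reading whole texels per T_LOAD along the image row, matching the float4/half4 granularity of cl_image reads.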