diff options
7 files changed, 88 insertions, 53 deletions
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h index ce37787862..797bda86cf 100644 --- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -32,7 +32,7 @@ class ICLTensor; /** OpenCL kernel to multiply two input matrices "A" and "B" . All elements of the output matrix will be multiplied by alpha * - * @note If the input tensors @p input0 and @p input1 have been reshaped respectively with @ref CLGEMMInterleave4x4Kernel" and @ref CLGEMMReshapeRHSMatrixKernel, + * @note If the input tensors @p input0 and @p input1 have been reshaped respectively with @ref CLGEMMReshapeLHSMatrixKernel" and @ref CLGEMMReshapeRHSMatrixKernel, * the flag @p is_interleaved_transposed must be set to true * * @attention The second input tensor must have at least 2 dimensions (matrix) @@ -57,7 +57,7 @@ public: * @param[in] input1 Input tensor containing the Matrix B. Data type supported: same as @p input0 * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product - * @param[in] is_interleaved_transposed (Optional) True if input0 and input1 have been reshaped respectively using @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMReshapeRHSMatrixKernel + * @param[in] is_interleaved_transposed (Optional) True if input0 and input1 have been reshaped respectively using @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel * @param[in] reshape_info (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy * @@ -70,7 +70,7 @@ public: * @param[in] input1 Input tensor containing the Matrix B. Data type supported: same as @p input0 * @param[in] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product - * @param[in] is_interleaved_transposed True if input0 and input1 have been reshaped respectively using @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMReshapeRHSMatrixKernel + * @param[in] is_interleaved_transposed True if input0 and input1 have been reshaped respectively using @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel * @param[in] reshape_info GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped * @param[in] gpu_target GPU Target * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h index c4accde23d..624df33ef6 100644 --- a/arm_compute/runtime/CL/functions/CLGEMM.h +++ b/arm_compute/runtime/CL/functions/CLGEMM.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2018 ARM Limited. + * Copyright (c) 2016-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,7 +24,6 @@ #ifndef __ARM_COMPUTE_CLGEMM_H__ #define __ARM_COMPUTE_CLGEMM_H__ -#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h" #include "arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h" @@ -41,8 +40,7 @@ class ICLTensor; /** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels: * - * -# @ref CLGEMMInterleave4x4Kernel (only if the reshaped GEMM is selected by the heuristic model and the GPU target is NOT Mali-G76) - * -# @ref CLGEMMReshapeLHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model and the GPU target IS Mali-G76) + * -# @ref CLGEMMReshapeLHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model) * -# @ref CLGEMMReshapeRHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model) * -# @ref CLGEMMMatrixMultiplyKernel (if GPU target is NOT G76 or if the reshaped GEMM is NOT selected) * -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if the reshaped GEMM is selected by the heuristic model and the GPU target IS Mali-G76) @@ -86,13 +84,13 @@ public: void configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMM. * - * @param[in] a First input tensor info (Matrix or Vector A). Data types supported: F16/F32 - * @param[in] b Second input tensor info (Matrix B). Data type supported: same as @p a. - * @param[in] c Third input tensor info (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a. - * @param[out] output Output tensor info. Data type supported: same as @p a - * @param[in] alpha Weight of the matrix product - * @param[in] beta Weight of matrix C - * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and + * @param[in] a First input tensor info (Matrix or Vector A). Data types supported: F16/F32 + * @param[in] b Second input tensor info (Matrix B). Data type supported: same as @p a. + * @param[in] c Third input tensor info (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a. + * @param[in] output Output tensor info. Data type supported: same as @p a + * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of matrix C + * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and * if the reshape of matrix B should happen only for the first run * * @return a status @@ -105,7 +103,6 @@ public: private: CLMemoryGroup _memory_group; - CLGEMMInterleave4x4Kernel _interleave_kernel; // TODO - COMPMID-1835: Remove this kernel and use CLGEMMReshapeLHSMatrixKernel CLGEMMMatrixMultiplyKernel _mm_kernel; CLGEMMMatrixAdditionKernel _ma_kernel; CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel; @@ -118,8 +115,8 @@ private: bool _run_addition; bool _reshape_b_only_on_first_run; bool _is_prepared; - bool _is_G76_path; // TODO: To be removed once completed COMPMID-1835 and COMPMID-1836 + bool _is_G76_path; }; -} +} // namespace arm_compute #endif /* __ARM_COMPUTE_CLGEMM_H__ */ diff --git a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h index 141354e723..72d91070f8 100644 --- a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h +++ b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,11 +24,11 @@ #ifndef __ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYCORE_H__ #define __ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYCORE_H__ -#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h" +#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h" #include "arm_compute/runtime/CL/CLMemoryGroup.h" #include "arm_compute/runtime/CL/CLTensor.h" @@ -41,7 +41,7 @@ class ICLTensor; /** Basic function to execute GEMMLowpMatrixMultiplyCore on OpenCL. This function calls the following OpenCL kernels: * - * -# @ref CLGEMMInterleave4x4Kernel (if the output tensor is a matrix) + * -# @ref CLGEMMReshapeLHSMatrixKernel (if the output tensor is a matrix) * -# @ref CLGEMMReshapeRHSMatrixKernel (if the output tensor is a matrix) * -# @ref CLGEMMLowpMatrixMultiplyKernel * -# @ref CLGEMMLowpMatrixAReductionKernel (if the offset of matrix B is not 0) @@ -101,7 +101,7 @@ public: private: CLMemoryGroup _memory_group; CLGEMMLowpMatrixMultiplyKernel _mm_kernel; - CLGEMMInterleave4x4Kernel _mtx_a_reshape_kernel; + CLGEMMReshapeLHSMatrixKernel _mtx_a_reshape_kernel; CLGEMMReshapeRHSMatrixKernel _mtx_b_reshape_kernel; CLGEMMLowpMatrixAReductionKernel _mtx_a_reduction_kernel; CLGEMMLowpMatrixBReductionKernel _mtx_b_reduction_kernel; diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp index 66fafe4de5..2c072a8ba0 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -72,16 +72,24 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, else { GEMMRHSMatrixInfo rhs_info; + GEMMLHSMatrixInfo lhs_info; const int m = reshape_info.m(); const int n = reshape_info.n(); const int k = reshape_info.k(); const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width(); const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height(); - rhs_info.n0 = 16 / input1->element_size(); - rhs_info.k0 = 1; - rhs_info.h0 = mult_transpose1xW_width; - rhs_info.interleave = false; - rhs_info.transpose = false; + const bool unroll_block = dot8_supported(CLKernelLibrary::get().get_device()); + + rhs_info.n0 = 16 / input1->element_size(); + rhs_info.k0 = 1; + rhs_info.h0 = mult_transpose1xW_width; + rhs_info.interleave = false; + rhs_info.transpose = false; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = mult_interleave4x4_height; + lhs_info.interleave = true; + lhs_info.transpose = unroll_block; TensorShape tensor_shape0{ input0->tensor_shape() }; tensor_shape0.set(0, k); @@ -94,7 +102,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0); const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1); - const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height)); + const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_lhs_reshaped_shape(tensor_info0, lhs_info)); const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info)); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0); diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp index 69455cf419..89fe7a4650 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,7 +40,8 @@ #include <set> #include <string> -using namespace arm_compute; +namespace arm_compute +{ using namespace arm_compute::misc::shape_calculator; namespace @@ -67,6 +68,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i else { GEMMRHSMatrixInfo rhs_info; + GEMMLHSMatrixInfo lhs_info; const int m = reshape_info.m(); const int n = reshape_info.n(); const int k = reshape_info.k(); @@ -77,6 +79,11 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i rhs_info.h0 = mult_transpose1xW_width; rhs_info.interleave = false; rhs_info.transpose = false; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = mult_interleave4x4_height; + lhs_info.interleave = true; + lhs_info.transpose = true; TensorShape tensor_shape0{ input0->tensor_shape() }; tensor_shape0.set(0, k); @@ -89,7 +96,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0); const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1); - const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height)); + const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_lhs_reshaped_shape(tensor_info0, lhs_info)); const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info)); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0); @@ -439,3 +446,4 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que } while(window.slide_window_slice_3D(slice)); } +} // namespace arm_compute diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 9048b85114..a3612f3b5d 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,8 @@ #include "arm_compute/runtime/CL/CLScheduler.h" #include "arm_compute/runtime/ITensorAllocator.h" -using namespace arm_compute; +namespace arm_compute +{ using namespace arm_compute::misc::shape_calculator; namespace @@ -117,7 +118,6 @@ inline void select_gemm_configuration(unsigned int m, unsigned int n, GEMMLHSMat CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager) : _memory_group(std::move(memory_manager)), - _interleave_kernel(), _mm_kernel(), _ma_kernel(), _reshape_lhs_kernel(), @@ -153,7 +153,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * const GPUTarget gpu_target = CLScheduler::get().target(); // Set the target for the kernels - _interleave_kernel.set_target(gpu_target); + _reshape_lhs_kernel.set_target(gpu_target); _mm_kernel.set_target(gpu_target); // Arguments used by GEMMReshapeInfo @@ -180,6 +180,13 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * rhs_info.interleave = false; rhs_info.transpose = false; + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = mult_interleave4x4_height; + lhs_info.interleave = true; + lhs_info.transpose = true; + // Check if we need to reshape the matrix A and matrix B _is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target); @@ -219,8 +226,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * else { // Configure interleave kernel - _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()); - + _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d()); // Configure transpose kernel _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info); } @@ -296,6 +302,13 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso rhs_info.interleave = false; rhs_info.transpose = false; + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = mult_interleave4x4_height; + lhs_info.interleave = true; + lhs_info.transpose = true; + // Check if we need to reshape the matrix A and matrix B const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target); @@ -335,8 +348,8 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso else { // Validate interleave kernel - auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()))); - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())); + auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d()))); + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d())); // Validate transpose kernel auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)); @@ -367,14 +380,7 @@ void CLGEMM::run() if(_is_interleaved_transposed) { // Run interleave kernel - if(_is_G76_path) - { - CLScheduler::get().enqueue(_reshape_lhs_kernel, false); - } - else - { - CLScheduler::get().enqueue(_interleave_kernel, false); - } + CLScheduler::get().enqueue(_reshape_lhs_kernel, false); if(!_reshape_b_only_on_first_run) { @@ -417,3 +423,4 @@ void CLGEMM::prepare() _is_prepared = true; } } +} // namespace arm_compute diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp index cf20bc6a7a..edb3107173 100644 --- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -32,7 +32,8 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLScheduler.h" -using namespace arm_compute; +namespace arm_compute +{ using namespace arm_compute::misc::shape_calculator; namespace @@ -109,6 +110,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor const ICLTensor *matrix_a = a; const ICLTensor *matrix_b = b; GEMMRHSMatrixInfo rhs_info; + GEMMLHSMatrixInfo lhs_info; // Arguments used by GEMMReshapeInfo // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo @@ -126,6 +128,11 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor rhs_info.h0 = mult_transpose1xW_width; rhs_info.interleave = false; rhs_info.transpose = false; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = mult_interleave4x4_height; + lhs_info.interleave = true; + lhs_info.transpose = unroll_block; // Check if we need to reshape the matrix A and matrix B _is_interleaved_transposed = is_interleaved_transposed(m, n, k, _reshape_b_only_on_first_run, gpu_target); @@ -145,7 +152,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor } // Configure interleave kernel - _mtx_a_reshape_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d(), unroll_block); + _mtx_a_reshape_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d()); // Configure transpose kernel _mtx_b_reshape_kernel.configure(b, &_tmp_b, rhs_info); @@ -242,8 +249,10 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso TensorInfo tmp_a_info{}; TensorInfo tmp_b_info{}; GEMMRHSMatrixInfo rhs_info; + GEMMLHSMatrixInfo lhs_info; bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const bool unroll_block = dot8_supported(CLKernelLibrary::get().get_device()); const int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); const int n = b->dimension(0); const int k = a->dimension(0); @@ -255,6 +264,11 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso rhs_info.h0 = mult_transpose1xW_width; rhs_info.interleave = false; rhs_info.transpose = false; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = mult_interleave4x4_height; + lhs_info.interleave = true; + lhs_info.transpose = unroll_block; bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target()); @@ -272,8 +286,8 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso matrix_b_info = &tmp_b_info; // Validate interleave kernel - auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()))); - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())); + auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d()))); + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d())); // Validate transpose kernel @@ -408,3 +422,4 @@ void CLGEMMLowpMatrixMultiplyCore::prepare() _is_prepared = true; } } +} // namespace arm_compute |