diff options
Diffstat (limited to 'arm_compute')
-rw-r--r-- | arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h | 1 | ||||
-rw-r--r-- | arm_compute/core/KernelDescriptors.h | 4 | ||||
-rw-r--r-- | arm_compute/runtime/CL/functions/CLGEMM.h | 4 |
3 files changed, 8 insertions, 1 deletions
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
index fc21f2a0f6..eab7fd219e 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
@@ -162,6 +162,7 @@ private:
     bool _add_bias;
     bool _broadcast_bias;
     bool _export_to_cl_image;
+    bool _has_pad_y;
 };
 } // namespace arm_compute
 #endif /*ARM_COMPUTE_CLGEMMMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H*/
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h
index 1ee1686fb1..ea46bfa5a6 100644
--- a/arm_compute/core/KernelDescriptors.h
+++ b/arm_compute/core/KernelDescriptors.h
@@ -64,6 +64,7 @@ struct GEMMKernelInfo
         bool ireinterpret_input_as_3d,
         bool ibroadcast_bias,
         bool ifp_mixed_precision,
+        bool ihas_pad_y,
         ActivationLayerInfo iactivation_info,
         int inmult_transpose1xW_width,
         int imult_interleave4x4_height,
@@ -72,7 +73,7 @@ struct GEMMKernelInfo
         int32_t ina_offset,
         int32_t inb_offset)
         : m(im), n(in), k(ik), depth_output_gemm3d(idepth_output_gemm3d), reinterpret_input_as_3d(ireinterpret_input_as_3d), broadcast_bias(ibroadcast_bias), fp_mixed_precision(ifp_mixed_precision),
-          activation_info(iactivation_info), mult_transpose1xW_width(inmult_transpose1xW_width), mult_interleave4x4_height(imult_interleave4x4_height), lhs_info(ilhs_info), rhs_info(irhs_info),
+          has_pad_y(ihas_pad_y), activation_info(iactivation_info), mult_transpose1xW_width(inmult_transpose1xW_width), mult_interleave4x4_height(imult_interleave4x4_height), lhs_info(ilhs_info), rhs_info(irhs_info),
           a_offset(ina_offset), b_offset(inb_offset)
     {
     }
@@ -84,6 +85,7 @@ struct GEMMKernelInfo
     bool reinterpret_input_as_3d{ false }; /**< Flag used to reinterpret the input as 3D */
     bool broadcast_bias{ false }; /**< Flag used to broadcast the bias addition */
     bool fp_mixed_precision{ false }; /**< Flag used to indicate wider accumulators (32 bit instead of 16 for FP16). */
+    bool has_pad_y{ false }; /**< Flag used to indicate if the input/output tensors have internal pad on the y direction */
     ActivationLayerInfo activation_info{}; /**< Activation function to perform after the matrix multiplication */
     int mult_transpose1xW_width{ 1 }; /**< Multiplication factor for the width of the 1xW transposed block */
     int mult_interleave4x4_height{ 1 }; /**< Multiplication factor for the height of the 4x4 interleaved block */
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 6e9cf0e2ca..92f9736e35 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -206,11 +206,15 @@ private:
     weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged _reshape_rhs_kernel_managed;
     CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel;
     CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel;
+    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_fallback_kernel;
     CLTensor _tmp_a;
     CLTensor _tmp_b;
     const ICLTensor *_original_b;
+    const ICLTensor *_lhs;
+    ICLTensor *_dst;
     bool _reshape_b_only_on_first_run;
     bool _is_prepared;
+    bool _has_pad_y;
     CLGEMMKernelType _gemm_kernel_type;
 };
 } // namespace arm_compute