From 94abde4f4e98f6f1adb5c46b194527f34a8ea07d Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Thu, 25 May 2023 16:48:43 +0100 Subject: Add Fused Activation to OpenCL MatMul - Added fused activation to MatMul function interface - Added fused activation to CL backend - Includes tests for supported Activation Functions in MatMul Resolves: [COMPMID-6192] Signed-off-by: Mohammed Suhail Munshi Change-Id: Ie103212b600b60699eaf6a6394d609e6e1f5aba6 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/522465 Comments-Addressed: bsgcomp Reviewed-by: Viet-Hoa Do Tested-by: bsgcomp Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9714 Comments-Addressed: Arm Jenkins Reviewed-by: Jakub Sujak Tested-by: Arm Jenkins Benchmark: Arm Jenkins --- src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp | 50 ++++++++++++++----------- src/gpu/cl/kernels/ClMatMulLowpNativeKernel.h | 21 ++++++----- src/gpu/cl/kernels/ClMatMulNativeKernel.cpp | 17 ++++++--- src/gpu/cl/kernels/ClMatMulNativeKernel.h | 20 +++++----- src/gpu/cl/operators/ClMatMul.cpp | 16 ++++---- src/gpu/cl/operators/ClMatMul.h | 18 +++++---- 6 files changed, 81 insertions(+), 61 deletions(-) (limited to 'src/gpu') diff --git a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp index d5ecdf7dd2..9bbec908a3 100644 --- a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp +++ b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp @@ -98,34 +98,36 @@ ClMatMulLowpNativeKernel::ClMatMulLowpNativeKernel() { _type = CLKernelType::GEMM; } -Status ClMatMulLowpNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *output, const MatMulKernelInfo &matmul_kernel_info) +Status ClMatMulLowpNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, const ActivationLayerInfo &act_info) { - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, output); + ARM_COMPUTE_UNUSED(act_info); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); - if(output->total_size() != 0) + if(dst->total_size() != 0) { - const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, output); + const TensorInfo tensor_info_output = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst); } return Status{}; } -void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *output, const MatMulKernelInfo &matmul_kernel_info) +void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, + const 
ActivationLayerInfo &act_info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, output, &compile_context, &matmul_kernel_info); - ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output, matmul_kernel_info); - ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, output, matmul_kernel_info)); + ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst, &compile_context, &matmul_kernel_info); + ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, matmul_kernel_info); + ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, dst, matmul_kernel_info)); // output tensor auto initialization if not yet initialized - auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); + auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); - const int m = output->dimension(1); - const int n = output->dimension(0); + const int m = dst->dimension(1); + const int n = dst->dimension(0); const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x(); const bool adj_lhs = matmul_kernel_info.adj_lhs; @@ -133,7 +135,7 @@ void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context int n0 = adjust_vec_size(matmul_kernel_info.n0, n); // Configure kernel window - Window win = calculate_max_window(*output, Steps(n0, m0)); + Window win = calculate_max_window(*dst, Steps(n0, m0)); win = win.collapse(win, Window::DimZ); IClKernel::configure_internal(win); @@ -152,7 +154,7 @@ void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context const UniformQuantizationInfo lqinfo = lhs->quantization_info().uniform(); const UniformQuantizationInfo rqinfo = rhs->quantization_info().uniform(); - const UniformQuantizationInfo dqinfo = output->quantization_info().uniform(); + const UniformQuantizationInfo dqinfo = dst->quantization_info().uniform(); float multiplier = lqinfo.scale * rqinfo.scale / dqinfo.scale; int output_multiplier = 0; @@ -166,6 +168,10 @@ void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context build_opts.add_option("-DRHS_OFFSET=" + support::cpp11::to_string(-rqinfo.offset)); // Note this is passed as negative to maintain similarity with CLDirectConv2D build_opts.add_option("-DDST_OFFSET=" + support::cpp11::to_string(dqinfo.offset)); // Passed as positive (unlike the above two) + build_opts.add_option(("-DA_VAL=" + float_to_string_with_full_precision(act_info.a()))); + build_opts.add_option(("-DB_VAL=" + float_to_string_with_full_precision(act_info.b()))); + build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_info.activation()))); + std::string kernel_name("mat_mul_native_quantized"); kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt"; kernel_name += matmul_kernel_info.adj_rhs ? 
"_t" : "_nt"; @@ -177,7 +183,7 @@ void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); // Set config_id for enabling LWS tuning - const size_t number_of_batches = output->tensor_shape().total_size() / (m * n); + const size_t number_of_batches = dst->tensor_shape().total_size() / (m * n); _config_id = kernel_name; _config_id += "_"; @@ -203,18 +209,18 @@ void ClMatMulLowpNativeKernel::run_op(ITensorPack &tensors, const Window &window ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); - const ICLTensor *lhs = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_0)); - const ICLTensor *rhs = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_1)); - ICLTensor *output = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); - ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, output); - ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output); + const ICLTensor *lhs = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_0)); + const ICLTensor *rhs = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_1)); + ICLTensor *dst = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); + ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst); + ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst); unsigned int idx = 0; Window window_collapsed = window.collapse(ICLKernel::window(), Window::DimZ); add_3d_tensor_nhw_argument(idx, lhs); add_3d_tensor_nhw_argument(idx, rhs); - add_3d_tensor_nhw_argument(idx, output); + add_3d_tensor_nhw_argument(idx, dst); enqueue(queue, *this, window_collapsed, lws_hint()); } diff --git a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.h b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.h index d70ff30b91..67d1a6601f 100644 --- a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.h +++ b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.h @@ -24,6 +24,7 @@ #ifndef ACL_SRC_GPU_CL_KERNELS_CLMATMULLOWPNATIVEKERNEL #define ACL_SRC_GPU_CL_KERNELS_CLMATMULLOWPNATIVEKERNEL +#include "arm_compute/core/ActivationLayerInfo.h" #include "src/core/common/Macros.h" #include "src/gpu/cl/ClCompileContext.h" #include "src/gpu/cl/IClKernel.h" @@ -43,22 +44,24 @@ public: ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClMatMulLowpNativeKernel); /** Initialise the kernel's input and output. * - * @param[in] compile_context The compile context to be used. - * @param[in] lhs Input tensor for the LHS matrix. Data type supported: QASYMM8_SIGNED/QASYMM8. - * Dimensions above 2 are collapsed onto dimension 2 and represent the batch. - * @param[in] rhs Input tensor for the RHS matrix. Data type supported: same as @p lhs. - * Dimensions above 2 are collapsed onto dimension 2 and represent the batch. - * @param[out] dst Output tensor info. Data type supported: same as @p lhs - * @param[in] matmul_info Attributes for Batch MatMul Kernel + * @param[in] compile_context The compile context to be used. + * @param[in] lhs Input tensor for the LHS matrix. Data type supported: QASYMM8_SIGNED/QASYMM8. + * Dimensions above 2 are collapsed onto dimension 2 and represent the batch. + * @param[in] rhs Input tensor for the RHS matrix. Data type supported: same as @p lhs. + * Dimensions above 2 are collapsed onto dimension 2 and represent the batch. + * @param[out] dst Output tensor info. 
Data type supported: same as @p lhs + * @param[in] matmul_kernel_info Attributes for Batch MatMul Kernel + * @param[in] act_info Class containing information about fused activation function. */ - void configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_info); + void configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, + const ActivationLayerInfo &act_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration * * Similar to @ref ClMatMulLowpNativeKernel::configure() * * @return a status */ - static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_info); + static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; diff --git a/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp index 8f53c1998f..205396a639 100644 --- a/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp +++ b/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp @@ -112,7 +112,7 @@ Status validate_export_to_cl_image(const ITensorInfo *rhs, const MatMulKernelInf ARM_COMPUTE_RETURN_ERROR_ON_MSG(!export_to_cl_image(rhs), "Export to CLImage is not supported for this device/configuration"); } - return Status {}; + return Status{}; } } ClMatMulNativeKernel::ClMatMulNativeKernel() @@ -120,8 +120,9 @@ ClMatMulNativeKernel::ClMatMulNativeKernel() _type = CLKernelType::GEMM; } -Status ClMatMulNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info) +Status ClMatMulNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, const ActivationLayerInfo &act_info) { + ARM_COMPUTE_UNUSED(act_info); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); @@ -138,7 +139,8 @@ Status ClMatMulNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo return Status{}; } -void ClMatMulNativeKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info) +void ClMatMulNativeKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, + const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst, &compile_context, &matmul_kernel_info); ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, matmul_kernel_info); @@ -176,6 +178,11 @@ void ClMatMulNativeKernel::configure(const ClCompileContext &compile_context, IT build_opts.add_option("-DK=" + support::cpp11::to_string(k)); build_opts.add_option_if_else(_export_rhs_to_cl_image, "-DRHS_TENSOR_TYPE=IMAGE", "-DRHS_TENSOR_TYPE=BUFFER"); + // Define values for activation function + build_opts.add_option(("-DA_VAL=" + float_to_string_with_full_precision(act_info.a()))); + build_opts.add_option(("-DB_VAL=" + 
float_to_string_with_full_precision(act_info.b()))); + build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_info.activation()))); + std::string kernel_name("mat_mul_native"); kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt"; kernel_name += matmul_kernel_info.adj_rhs ? "_t" : "_nt"; @@ -218,8 +225,8 @@ void ClMatMulNativeKernel::run_op(ITensorPack &tensors, const Window &window, cl ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); - const ICLTensor *lhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0)); - const ICLTensor *rhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); + const ICLTensor *lhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0)); + const ICLTensor *rhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); ICLTensor *dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst); diff --git a/src/gpu/cl/kernels/ClMatMulNativeKernel.h b/src/gpu/cl/kernels/ClMatMulNativeKernel.h index f706256e31..02d8ac3067 100644 --- a/src/gpu/cl/kernels/ClMatMulNativeKernel.h +++ b/src/gpu/cl/kernels/ClMatMulNativeKernel.h @@ -42,22 +42,24 @@ public: ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClMatMulNativeKernel); /** Initialise the kernel's input and output. * - * @param[in] compile_context The compile context to be used. - * @param[in] lhs Input tensor for the LHS matrix. Data type supported: F32/F16. - * Dimensions above 2 are collapsed onto dimension 2 and represent the batch. - * @param[in] rhs Input tensor for the RHS matrix. Data type supported: same as @p lhs. - * Dimensions above 2 are collapsed onto dimension 2 and represent the batch. - * @param[out] dst Output tensor info. Data type supported: same as @p lhs - * @param[in] matmul_info Attributes for Batch MatMul Kernel + * @param[in] compile_context The compile context to be used. + * @param[in] lhs Input tensor for the LHS matrix. Data type supported: F32/F16. + * Dimensions above 2 are collapsed onto dimension 2 and represent the batch. + * @param[in] rhs Input tensor for the RHS matrix. Data type supported: same as @p lhs. + * Dimensions above 2 are collapsed onto dimension 2 and represent the batch. + * @param[out] dst Output tensor info. Data type supported: same as @p lhs + * @param[in] matmul_kernel_info Attributes for Batch MatMul Kernel + * @param[in] act_info Specifies activation function to use after Matrix multiplication. Default is Identity function. 
 */ - void configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_info); + void configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, + const ActivationLayerInfo &act_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration * * Similar to @ref ClMatMulNativeKernel::configure() * * @return a status */ - static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_info); + static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; diff --git a/src/gpu/cl/operators/ClMatMul.cpp b/src/gpu/cl/operators/ClMatMul.cpp index 3822c16aa1..c453761a8e 100644 --- a/src/gpu/cl/operators/ClMatMul.cpp +++ b/src/gpu/cl/operators/ClMatMul.cpp @@ -47,7 +47,7 @@ ClMatMul::ClMatMul() { } -Status ClMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &matmul_info) +Status ClMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &matmul_info, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); @@ -57,15 +57,15 @@ Status ClMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const std::unique_ptr<IClMatMulNativeKernelConfig> t = ClMatMulNativeKernelConfigurationFactory::create(gpu_target); - MatMulKernelInfo kernel_info = t->configure(lhs, rhs, matmul_info); + const MatMulKernelInfo kernel_info = t->configure(lhs, rhs, matmul_info); - bool is_quantized = is_data_type_quantized_asymmetric(lhs->data_type()); + const bool is_quantized = is_data_type_quantized_asymmetric(lhs->data_type()); - return is_quantized ? ClMatMulLowpNativeKernel::validate(lhs, rhs, dst, kernel_info) : - ClMatMulNativeKernel::validate(lhs, rhs, dst, kernel_info); + return is_quantized ? 
ClMatMulLowpNativeKernel::validate(lhs, rhs, dst, kernel_info, act_info) : + ClMatMulNativeKernel::validate(lhs, rhs, dst, kernel_info, act_info); } -void ClMatMul::configure(const CLCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &matmul_info) +void ClMatMul::configure(const CLCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &matmul_info, const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, matmul_info); @@ -86,14 +86,14 @@ void ClMatMul::configure(const CLCompileContext &compile_context, ITensorInfo *l _matmul_lowp_native_kernel->set_target(gpu_target); // Configure the low-precision native matrix multiply kernel - _matmul_lowp_native_kernel->configure(compile_context, lhs, rhs, dst, kernel_info); + _matmul_lowp_native_kernel->configure(compile_context, lhs, rhs, dst, kernel_info, act_info); } else { _matmul_native_kernel->set_target(gpu_target); // Configure the native matrix multiply kernel - _matmul_native_kernel->configure(compile_context, lhs, rhs, dst, kernel_info); + _matmul_native_kernel->configure(compile_context, lhs, rhs, dst, kernel_info, act_info); } } diff --git a/src/gpu/cl/operators/ClMatMul.h b/src/gpu/cl/operators/ClMatMul.h index 6aba801301..9dce5288e6 100644 --- a/src/gpu/cl/operators/ClMatMul.h +++ b/src/gpu/cl/operators/ClMatMul.h @@ -21,14 +21,14 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ACL_ARM_COMPUTE_SRC_GPU_CL_OPERATORS_CLMATMUL -#define ACL_ARM_COMPUTE_SRC_GPU_CL_OPERATORS_CLMATMUL +#ifndef ACL_SRC_GPU_CL_OPERATORS_CLMATMUL +#define ACL_SRC_GPU_CL_OPERATORS_CLMATMUL #include "arm_compute/core/ActivationLayerInfo.h" #include "arm_compute/core/MatMulInfo.h" #include "src/gpu/cl/IClOperator.h" -#include "src/gpu/cl/kernels/ClMatMulNativeKernel.h" #include "src/gpu/cl/kernels/ClMatMulLowpNativeKernel.h" +#include "src/gpu/cl/kernels/ClMatMulNativeKernel.h" #include <memory> @@ -71,24 +71,26 @@ public: * @param[in] rhs Right-hand side tensor info. Data types supported: same as @p lhs. * @param[out] dst Output tensor to store the result of the batched matrix multiplication. Data types supported: same as @p lhs. * @param[in] matmul_info Contains MatMul operation information described in @ref MatMulInfo. + * @param[in] act_info Class containing information about fused activation function. 
 */ - void configure(const CLCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &matmul_info); + void configure(const CLCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &matmul_info, + const ActivationLayerInfo &act_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration * * Similar to @ref ClMatMul::configure() * * @return a status */ - static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &matmul_info); + static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &matmul_info, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overridden: void run(ITensorPack &tensors) override; private: - std::unique_ptr<kernels::ClMatMulNativeKernel> _matmul_native_kernel{nullptr}; - std::unique_ptr<kernels::ClMatMulLowpNativeKernel> _matmul_lowp_native_kernel{nullptr}; + std::unique_ptr<kernels::ClMatMulNativeKernel> _matmul_native_kernel{ nullptr }; + std::unique_ptr<kernels::ClMatMulLowpNativeKernel> _matmul_lowp_native_kernel{ nullptr }; bool _is_quantized{ false }; }; } // namespace opencl } // namespace arm_compute -#endif /* ACL_ARM_COMPUTE_SRC_GPU_CL_OPERATORS_CLMATMUL */ +#endif /* ACL_SRC_GPU_CL_OPERATORS_CLMATMUL */ -- cgit v1.2.1
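
Note: below is a minimal caller-side sketch of the extended interface, showing how the new act_info parameter added by this patch is intended to be used. The snippet is not part of the patch itself: the tensor shapes, the BOUNDED_RELU choice, the helper function name and the use of the default CLKernelLibrary compile context are illustrative assumptions, and ClMatMul is an internal operator rather than public API.

    #include "arm_compute/core/ActivationLayerInfo.h"
    #include "arm_compute/core/CL/CLKernelLibrary.h"
    #include "arm_compute/core/MatMulInfo.h"
    #include "arm_compute/core/TensorInfo.h"
    #include "src/gpu/cl/operators/ClMatMul.h"

    using namespace arm_compute;

    void configure_fused_matmul_example() // hypothetical helper, for illustration only
    {
        // LHS is M x K, RHS is K x N. TensorShape is ordered (width, height, ...),
        // so the shapes below encode M = 4, K = 8, N = 2 (values are illustrative).
        TensorInfo lhs(TensorShape(8U, 4U), 1, DataType::F32);
        TensorInfo rhs(TensorShape(2U, 8U), 1, DataType::F32);
        TensorInfo dst{}; // left empty: configure() auto-initialises it to M x N

        const MatMulInfo matmul_info{}; // no LHS/RHS transposition

        // The new parameter: fuse a ReLU6-style activation into the MatMul kernel
        // instead of running a separate activation layer on the output.
        const ActivationLayerInfo act_info(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);

        if(bool(opencl::ClMatMul::validate(&lhs, &rhs, &dst, matmul_info, act_info)))
        {
            opencl::ClMatMul matmul{};
            matmul.configure(CLKernelLibrary::get().get_compile_context(), &lhs, &rhs, &dst, matmul_info, act_info);
            // At run time, bind CL tensors in an ITensorPack under ACL_SRC_0,
            // ACL_SRC_1 and ACL_DST, then call matmul.run(pack).
        }
    }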