diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp')
-rw-r--r-- | src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp index 02c5754672..a0eb3f2853 100644 --- a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp +++ b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp @@ -100,7 +100,8 @@ ClMatMulLowpNativeKernel::ClMatMulLowpNativeKernel() { _type = CLKernelType::GEMM; } -Status ClMatMulLowpNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, const ActivationLayerInfo &act_info) +Status ClMatMulLowpNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *bias, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, + const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); @@ -111,24 +112,32 @@ Status ClMatMulLowpNativeKernel::validate(const ITensorInfo *lhs, const ITensorI ARM_COMPUTE_RETURN_ERROR_ON_MSG((act_info.activation() != ActivationFunction::IDENTITY && act_info.activation() != ActivationFunction::RELU && act_info.activation() != ActivationFunction::LU_BOUNDED_RELU && act_info.activation() != ActivationFunction::BOUNDED_RELU), "Activation Function specified is unsupported."); + const TensorShape expected_output_shape = misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info); if(dst->total_size() != 0) { - const TensorInfo tensor_info_output = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + const TensorInfo tensor_info_output = dst->clone()->set_tensor_shape(expected_output_shape); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst); } + if(bias != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(expected_output_shape[0] != bias->dimension(0)); + } + return Status{}; } -void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, +void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *bias, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info, const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst, &compile_context, &matmul_kernel_info); - ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, matmul_kernel_info); - ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, dst, matmul_kernel_info)); + ARM_COMPUTE_LOG_PARAMS(lhs, rhs, bias, dst, matmul_kernel_info); + ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, bias, dst, matmul_kernel_info)); - // output tensor auto initialization if not yet initialized + // dst tensor auto initialization if not yet initialized auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); const int m = dst->dimension(1); @@ -172,7 +181,8 @@ void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context // Note : Offset is not negated, unlike gemmlowp kernels build_opts.add_option("-DLHS_OFFSET=" + support::cpp11::to_string(lqinfo.offset)); build_opts.add_option("-DRHS_OFFSET=" + support::cpp11::to_string(rqinfo.offset)); - build_opts.add_option("-DDST_OFFSET=" + support::cpp11::to_string(dqinfo.offset)); // Passed as positive (unlike the above two) + build_opts.add_option("-DDST_OFFSET=" + support::cpp11::to_string(dqinfo.offset)); + build_opts.add_option_if(bias != nullptr, "-DBIAS"); // Floating point boundaries are quantized prior to being passed as arguments. // Note: We expect the input and output tensors to always adopt a per-tensor quantization approach @@ -222,17 +232,22 @@ void ClMatMulLowpNativeKernel::run_op(ITensorPack &tensors, const Window &window ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); - const ICLTensor *lhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0)); - const ICLTensor *rhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); - ICLTensor *dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); + const ICLTensor *lhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0)); + const ICLTensor *rhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); + const ICLTensor *bias = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_2)); + ICLTensor *dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst); - ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst); + ARM_COMPUTE_LOG_PARAMS(lhs, rhs, bias, dst); unsigned int idx = 0; Window window_collapsed = window.collapse(ICLKernel::window(), Window::DimZ); add_3d_tensor_nhw_argument(idx, lhs); add_3d_tensor_nhw_argument(idx, rhs); + if(bias != nullptr) + { + add_3d_tensor_nhw_argument(idx, bias); + } add_3d_tensor_nhw_argument(idx, dst); enqueue(queue, *this, window_collapsed, lws_hint()); |