diff options
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp')
-rw-r--r-- | src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp index 1521973d55..e70e5d5ea5 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.cpp @@ -24,6 +24,9 @@ #if defined(ENABLE_EXPERIMENTAL_DYNAMIC_FUSION) #include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h" +#include "arm_compute/core/TensorInfo.h" +#include "src/core/AccessWindowStatic.h" +#include "src/core/helpers/WindowHelpers.h" namespace arm_compute { @@ -41,6 +44,92 @@ std::set<std::string> ClGemmNativeKernelComponent::get_headers_list() const return std::set<std::string> { "./common/experimental/gemm_fused_post_ops/act_eltwise_op_act/fp_post_ops_act_eltwise_op_act.h", "gemm_helpers.h", "repeat.h" }; }
+/** Build the execution window for this GEMM component.
+ *
+ * Steps performed, in order:
+ *  -# Fetch the lhs, rhs, bias and dst tensor infos from the blueprint (bias may be null).
+ *  -# Auto-initialize dst from lhs with the shape returned by compute_mm_shape() if dst is still empty.
+ *  -# Compute two max windows stepped by (rhs_info.n0, lhs_info.m0): one on a (possibly collapsed)
+ *     copy of dst, and one on dst itself.
+ *  -# Run update_window_and_padding() on the operand access windows.
+ *  -# Collapse the first window via Window::collapse(), passing min(dst num_dimensions, 2).
+ *
+ * @note Calls ARM_COMPUTE_ERROR("Insufficient Padding!") when update_window_and_padding() reports a change.
+ *
+ * @return The collapsed window.
+ */
+Window ClGemmNativeKernelComponent::get_window() const
+{
+    ITensorInfo *lhs_info = _blueprint->impl().get_kernel_argument_info(_lhs.arg_id);
+    ITensorInfo *rhs_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id);
+    ITensorInfo *bias_info = _blueprint->impl().get_kernel_argument_info(_bias.arg_id);
+    ITensorInfo *dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id());
+
+    // bias_info is deliberately left out of the check: it is treated as optional further down
+    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs_info, rhs_info, dst_info);
+
+    bool reinterpret_input_as_3d = _desc.reinterpret_input_as_3d;
+    bool reinterpret_output_as_3d = _desc.depth_output_gemm3d != 0;
+
+    Window win{};
+    Window win_out{};
+    bool window_changed = false;
+
+    // In case both input and dst have to be reinterpreted as 3D tensors,
+    // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
+    // NOTE(review): only reinterpret_output_as_3d is reset here; reinterpret_input_as_3d is not read
+    // again in this function. The condition is also true when both flags are false (a no-op reset).
+    if(reinterpret_input_as_3d == reinterpret_output_as_3d)
+    {
+        reinterpret_output_as_3d = false;
+    }
+
+    // activation_layer is set to dummy because it's required by GEMMKernelInfo, but it's not used in shape calculation
+    // NOTE: gemm_info receives the raw _desc values, not the locally adjusted reinterpret_output_as_3d flag
+    GEMMKernelInfo gemm_info(_desc.m, _desc.n, _desc.k, _desc.depth_output_gemm3d, _desc.reinterpret_input_as_3d,
+                             _desc.broadcast_bias, _desc.fp_mixed_precision, _desc.has_pad_y, ActivationLayerInfo(), _desc.nmult_transpose1xW_width,
+                             _desc.mult_interleave4x4_height, _desc.lhs_info, _desc.rhs_info, _desc.a_offset, _desc.b_offset);
+
+    // dst tensor auto initialization if not yet initialized
+    auto_init_if_empty(*dst_info, lhs_info->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*lhs_info, *rhs_info, gemm_info)));
+
+    // Copy taken after the auto-initialization above, so the shape tweak below does not touch dst_info
+    TensorInfo tmp_info(*dst_info);
+
+    if(reinterpret_output_as_3d)
+    {
+        // Since the dst tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM,
+        // the window needs to be constructed on the 2D collapsed version of the tensor
+        TensorShape tmp_shape(dst_info->tensor_shape());
+        tmp_shape.collapse(2U, 1U);
+        tmp_info.set_tensor_shape(tmp_shape);
+    }
+
+    // Both windows use the same steps: n0 first, m0 second
+    win = calculate_max_window(tmp_info, Steps(_desc.rhs_info.n0, _desc.lhs_info.m0));
+    win_out = calculate_max_window(*dst_info, Steps(_desc.rhs_info.n0, _desc.lhs_info.m0));
+
+    // lhs and dst access windows use the plain tensor dimensions; rhs (and bias below) round
+    // dimension 0 up to a multiple of n0
+    AccessWindowStatic src0_access(lhs_info, 0, 0,
+                                   lhs_info->dimension(0),
+                                   lhs_info->dimension(1));
+    AccessWindowStatic src1_access(rhs_info, 0, 0,
+                                   ceil_to_multiple(rhs_info->dimension(0), _desc.rhs_info.n0),
+                                   rhs_info->dimension(1));
+    AccessWindowStatic dst_access(dst_info, 0, 0,
+                                  dst_info->dimension(0),
+                                  dst_info->dimension(1));
+
+    // NOTE: in both branches `||` short-circuits, so the win_out/dst_access call is skipped
+    // whenever the first update_window_and_padding() call already returns true
+    if(bias_info != nullptr)
+    {
+        const int bias_processed_per_iteration_x = _desc.rhs_info.n0;
+
+        AccessWindowStatic src2_access(bias_info, 0, 0,
+                                       ceil_to_multiple(bias_info->dimension(0), bias_processed_per_iteration_x),
+                                       bias_info->dimension(1));
+
+        window_changed = update_window_and_padding(win, src0_access, src1_access, src2_access) || // window used by the execute_window_loop
+                         update_window_and_padding(win_out, dst_access); // window used to update the padding requirements of dst tensor
+    }
+    else
+    {
+        window_changed = update_window_and_padding(win, src0_access, src1_access) || // window used by the execute_window_loop
+                         update_window_and_padding(win_out, dst_access); // window used to update the padding requirements of dst tensor
+    }
+
+    // Collapse along the Z direction
+    // This collapse needs to be here in order to tune the Z dimension of LWS
+    // NOTE: the copy of win made on the next line is overwritten by the collapse() result two lines later
+    Window collapsed = win;
+    const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(dst_info->num_dimensions()), 2u);
+    collapsed = win.collapse(win, dimension_to_collapse);
+
+    // The check runs after the collapse; a reported window change is treated as a hard error
+    if(window_changed == true)
+    {
+        ARM_COMPUTE_ERROR("Insufficient Padding!");
+    }
+
+    return collapsed;
+}
+
std::string ClGemmNativeKernelComponent::get_additional_macros() const { return R"_( |