diff options
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp')
-rw-r--r-- | src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp
index a44b5faee2..06c29c4253 100644
--- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp
+++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp
@@ -24,6 +24,9 @@
 #if defined(ENABLE_EXPERIMENTAL_DYNAMIC_FUSION)
 #include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.h"
+#include "arm_compute/core/Validate.h"
+#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/helpers/WindowHelpers.h"
 namespace arm_compute
 {
@@ -41,6 +44,26 @@ std::set<std::string> ClElementwiseAddKernelComponent::get_headers_list() const
     return std::set<std::string> { "gemm_helpers.h", "repeat.h" };
 }
+Window ClElementwiseAddKernelComponent::get_window() const // Builds and returns the Window for this component from the destination tensor info; dst_info may be written as a side effect (see auto_init_if_empty below).
+{
+    const ITensorInfo *lhs_info = _blueprint->impl().get_kernel_argument_info(_lhs.arg_id); // Tensor info the blueprint holds for the argument id of the left operand.
+    const ITensorInfo *rhs_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id); // Same lookup for the right operand.
+    ITensorInfo *dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id()); // Non-const on purpose: it is handed to auto_init_if_empty() below.
+
+    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs_info, rhs_info, dst_info); // Guards the dereferences that follow. NOTE(review): whether this macro is active in release builds is not visible here - confirm.
+
+    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*lhs_info, *rhs_info); // Shape/valid-region pair derived from both operands.
+    const TensorShape &out_shape = broadcast_pair.first; // Only the shape is used; the ValidRegion half of the pair is ignored.
+
+    auto_init_if_empty(*dst_info, out_shape, 1, lhs_info->data_type()); // Presumably configures dst_info (out_shape, lhs data type) only when it is still empty, with 1 as the channel count - TODO confirm against the helper's declaration.
+
+    const unsigned int vector_size_byte_opencl = 16; // Byte budget used to derive the per-iteration element count.
+    const unsigned int num_elems_processed_per_iteration = adjust_vec_size(vector_size_byte_opencl / dst_info->element_size(), dst_info->dimension(0)); // Starts from 16 / element size; NOTE(review): adjust_vec_size presumably adapts that count to dimension 0 of dst - exact rule not visible here.
+    Window win = calculate_max_window(*dst_info, Steps(num_elems_processed_per_iteration)); // Window computed from dst_info with the per-iteration element count passed as Steps.
+
+    return win;
+}
+
 std::string
ClElementwiseAddKernelComponent::get_component_code() const { std::string code; |