aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp')
-rw-r--r--src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp18
1 files changed, 14 insertions, 4 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp
index 2bbea8725d..47f95b5c40 100644
--- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp
+++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp
@@ -88,7 +88,12 @@ std::string ClElementwiseAddKernelComponent::get_component_code() const
T_LOAD({{DATA_TYPE}}, M0, N0, BUFFER, {{lhs}}, cout, mout, 1, {{lhs}}_stride_y, lhs_tile);
T_LOAD({{DATA_TYPE}}, M0, N0, BUFFER, {{rhs}}, cout, mout, 1, {{rhs}}_stride_y, rhs_tile);
+#if defined(IS_BROADCAST)
T_ADD_BROADCAST_X({{DATA_TYPE}}, M0, N0, lhs_tile, rhs_tile, {{dst}});
+#else // !defined(IS_BROADCAST)
+ T_ADD({{DATA_TYPE}}, M0, N0, lhs_tile, rhs_tile, {{dst}});
+#endif // defined(IS_BROADCAST)
+
}
//------------------ END KERNEL {{meta_kernel_id}} ELTWISE_ADD ---------------------
)_";
@@ -106,7 +111,11 @@ std::string ClElementwiseAddKernelComponent::get_component_code() const
T_LOAD({{DATA_TYPE}}, M0, N0, BUFFER, {{addend}}, cout, mout, 1, {{addend}}_stride_y, addend_tile);
+#if defined(IS_BROADCAST)
T_ADD_BROADCAST_X({{DATA_TYPE}}, M0, N0, {{acc}}, addend_tile, {{acc}});
+#else // !defined(IS_BROADCAST)
+ T_ADD({{DATA_TYPE}}, M0, N0, {{acc}}, addend_tile, {{acc}});
+#endif // defined(IS_BROADCAST)
}
//------------------ END KERNEL {{meta_kernel_id}} ELTWISE_ADD ---------------------
)_";
@@ -115,16 +124,17 @@ std::string ClElementwiseAddKernelComponent::get_component_code() const
CLBuildOptions ClElementwiseAddKernelComponent::generate_build_options() const
{
+ const auto t_src_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id);
const auto t_dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id());
CLBuildOptions build_opts{};
- const auto n0 = _blueprint->impl().get_execution_window().x().step();
- const auto m0 = _blueprint->impl().get_execution_window().y().step();
- const auto partial_m0 = t_dst_info->dimension(1) % m0;
+ const auto n0 = _blueprint->impl().get_execution_window().x().step();
+ const auto m0 = _blueprint->impl().get_execution_window().y().step();
+ const bool is_broadcast = t_src_info->tensor_shape() != t_dst_info->tensor_shape();
build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DPARTIAL_STORE_M0=" + support::cpp11::to_string(partial_m0));
+ build_opts.add_option_if(is_broadcast, "-DIS_BROADCAST");
return build_opts;
}