aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp')
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp29
1 files changed, 15 insertions, 14 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
index 01017ed909..0dd7ca5e78 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
@@ -68,7 +68,7 @@ std::string ClTemplateElementwiseBinary::get_component_code(const ComponentGroup
code =
R"_(
- //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_OP ---------------------
+ //------------------ START KERNEL {{meta_kernel_id}} {{ELTWISE_OP}} ---------------------
)_";
if(is_root)
@@ -139,7 +139,7 @@ R"_(
code +=
R"_(
}
- //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_OP ---------------------
+ //------------------ END KERNEL {{meta_kernel_id}} {{ELTWISE_OP}} ---------------------
)_";
return code;
@@ -168,33 +168,34 @@ void ClTemplateElementwiseBinary::declare_variables(GpuKernelVariableTable &vtab
TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
{
- TagLUT lut{};
+ TagLUT lut{};
// Local build options
lut["meta_kernel_id"] = id();
lut["DATA_TYPE"] = get_cl_type_from_data_type(_lhs->data_type());
// Arguments and global shared variables
- lut["lhs"] = vtable.get_variable(_lhs);
- lut["rhs"] = vtable.get_variable(_rhs);
- lut["dst"] = vtable.get_variable(_dst);
+ lut["lhs"] = vtable.get_variable(_lhs);
+ lut["rhs"] = vtable.get_variable(_rhs);
+ lut["dst"] = vtable.get_variable(_dst);
lut["arg_dst"] = vtable.get_variable(comp_group.get_any_dst_tensor());
switch(_attributes.operation())
{
- case Attributes::ElementwiseOp::ADD:
+ case Attributes::ElementwiseOp::Add:
lut["ELTWISE_OP"] = "ADD";
break;
+ case Attributes::ElementwiseOp::Mul:
+ lut["ELTWISE_OP"] = "MUL";
+ break;
default:
ARM_COMPUTE_ERROR("Arithmetic Operation not supported");
}
ARM_COMPUTE_ERROR_ON(
- comp_group.is_intermediate_tensor(_lhs) &&
- detail::have_different_dimensions(_lhs->tensor_shape(), _dst->tensor_shape(), 0));
+ comp_group.is_intermediate_tensor(_lhs) && detail::have_different_dimensions(_lhs->tensor_shape(), _dst->tensor_shape(), 0));
ARM_COMPUTE_ERROR_ON(
- comp_group.is_intermediate_tensor(_rhs) &&
- detail::have_different_dimensions(_rhs->tensor_shape(), _dst->tensor_shape(), 0));
+ comp_group.is_intermediate_tensor(_rhs) && detail::have_different_dimensions(_rhs->tensor_shape(), _dst->tensor_shape(), 0));
// Set broadcast parameters
// PRE: All tensors are broadcast-compatible
@@ -222,9 +223,9 @@ TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vt
lut["rhs_m0"] = (rhs_broadcast_yz) ? "1" : "M0";
lut["rhs_start_ind_1"] = (rhs_broadcast_yz) ? "0" : "g_ind_1";
- lut["BROADCAST_OP"] = (lhs_broadcast_yz) ? "BROADCAST_LHS_X_" :
- (rhs_broadcast_yz) ? "BROADCAST_RHS_X_" :
- "";
+ lut["BROADCAST_OP"] = (lhs_broadcast_yz) ? "BROADCAST_LHS_X_" :
+ (rhs_broadcast_yz) ? "BROADCAST_RHS_X_" :
+ "";
return lut;
}