diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp index 01017ed909..0dd7ca5e78 100644 --- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp +++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp @@ -68,7 +68,7 @@ std::string ClTemplateElementwiseBinary::get_component_code(const ComponentGroup code = R"_( - //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_OP --------------------- + //------------------ START KERNEL {{meta_kernel_id}} {{ELTWISE_OP}} --------------------- )_"; if(is_root) @@ -139,7 +139,7 @@ R"_( code += R"_( } - //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_OP --------------------- + //------------------ END KERNEL {{meta_kernel_id}} {{ELTWISE_OP}} --------------------- )_"; return code; @@ -168,33 +168,34 @@ void ClTemplateElementwiseBinary::declare_variables(GpuKernelVariableTable &vtab TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const { - TagLUT lut{}; + TagLUT lut{}; // Local build options lut["meta_kernel_id"] = id(); lut["DATA_TYPE"] = get_cl_type_from_data_type(_lhs->data_type()); // Arguments and global shared variables - lut["lhs"] = vtable.get_variable(_lhs); - lut["rhs"] = vtable.get_variable(_rhs); - lut["dst"] = vtable.get_variable(_dst); + lut["lhs"] = vtable.get_variable(_lhs); + lut["rhs"] = vtable.get_variable(_rhs); + lut["dst"] = vtable.get_variable(_dst); lut["arg_dst"] = vtable.get_variable(comp_group.get_any_dst_tensor()); switch(_attributes.operation()) { - case Attributes::ElementwiseOp::ADD: + case Attributes::ElementwiseOp::Add: lut["ELTWISE_OP"] = "ADD"; break; + case Attributes::ElementwiseOp::Mul: + lut["ELTWISE_OP"] = "MUL"; + break; default: ARM_COMPUTE_ERROR("Arithmetic Operation not supported"); } ARM_COMPUTE_ERROR_ON( - comp_group.is_intermediate_tensor(_lhs) && - detail::have_different_dimensions(_lhs->tensor_shape(), _dst->tensor_shape(), 0)); + comp_group.is_intermediate_tensor(_lhs) && detail::have_different_dimensions(_lhs->tensor_shape(), _dst->tensor_shape(), 0)); ARM_COMPUTE_ERROR_ON( - comp_group.is_intermediate_tensor(_rhs) && - detail::have_different_dimensions(_rhs->tensor_shape(), _dst->tensor_shape(), 0)); + comp_group.is_intermediate_tensor(_rhs) && detail::have_different_dimensions(_rhs->tensor_shape(), _dst->tensor_shape(), 0)); // Set broadcast parameters // PRE: All tensors are broadcast-compatible @@ -222,9 +223,9 @@ TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vt lut["rhs_m0"] = (rhs_broadcast_yz) ? "1" : "M0"; lut["rhs_start_ind_1"] = (rhs_broadcast_yz) ? "0" : "g_ind_1"; - lut["BROADCAST_OP"] = (lhs_broadcast_yz) ? "BROADCAST_LHS_X_" : - (rhs_broadcast_yz) ? "BROADCAST_RHS_X_" : - ""; + lut["BROADCAST_OP"] = (lhs_broadcast_yz) ? "BROADCAST_LHS_X_" : + (rhs_broadcast_yz) ? "BROADCAST_RHS_X_" : + ""; return lut; } |