aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/template_writer/cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/cl')
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp5
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp7
2 files changed, 9 insertions, 3 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
index 996bf15d01..6c1e0fb1de 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
@@ -110,8 +110,8 @@ R"_(
R"_(
LOOP_UNROLLING(int, i, 0, 1, M0,
{
- g_dst_indirect_y[i].v = (uint)min(g_ind_1 + i, (int)({{dst}}_w * {{dst}}_h) - 1);
- g_dst_indirect_y[i].v += g_ind_2 * (int)({{dst}}_w * {{dst}}_h);
+ g_dst_indirect_y[i].v = (uint)min(g_ind_1 + i, (int)({{out}}_w * {{out}}_h) - 1);
+ g_dst_indirect_y[i].v += g_ind_2 * (int)({{out}}_w * {{out}}_h);
})
}
//------------------ END KERNEL {{meta_kernel_id}} ELTWISE_OP ---------------------
@@ -194,6 +194,7 @@ TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vt
lut["lhs"] = vtable.get_variable(_lhs);
lut["rhs"] = vtable.get_variable(_rhs);
lut["dst"] = vtable.get_variable(_dst);
+ lut["out"] = vtable.get_variable(comp_group.get_dst_tensors().front());
}
else
{
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp
index cb643a741d..0afd0e7581 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp
@@ -179,7 +179,11 @@ std::string ClTemplateWriter::write_code()
code += macros;
}
- code += write_kernel_signature(_vtable.get_variable_list(_components.get_argument_tensors()));
+ auto arguments = _components.get_argument_tensors();
+ std::sort(arguments.begin(), arguments.end(), [](const ITensorInfo *l, const ITensorInfo *r) {
+ return l->id() < r->id();
+ });
+ code += write_kernel_signature(_vtable.get_variable_list(arguments));
code += "\n{\n\n";
@@ -190,6 +194,7 @@ std::string ClTemplateWriter::write_code()
for(const auto &component_code : component_codes)
{
code += component_code;
+ code += "\n";
}
code += "}\n";