diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/cl')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp | 5 | ||||
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp | 7 |
2 files changed, 9 insertions, 3 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp index 996bf15d01..6c1e0fb1de 100644 --- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp +++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp @@ -110,8 +110,8 @@ R"_( R"_( LOOP_UNROLLING(int, i, 0, 1, M0, { - g_dst_indirect_y[i].v = (uint)min(g_ind_1 + i, (int)({{dst}}_w * {{dst}}_h) - 1); - g_dst_indirect_y[i].v += g_ind_2 * (int)({{dst}}_w * {{dst}}_h); + g_dst_indirect_y[i].v = (uint)min(g_ind_1 + i, (int)({{out}}_w * {{out}}_h) - 1); + g_dst_indirect_y[i].v += g_ind_2 * (int)({{out}}_w * {{out}}_h); }) } //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_OP --------------------- @@ -194,6 +194,7 @@ TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vt lut["lhs"] = vtable.get_variable(_lhs); lut["rhs"] = vtable.get_variable(_rhs); lut["dst"] = vtable.get_variable(_dst); + lut["out"] = vtable.get_variable(comp_group.get_dst_tensors().front()); } else { diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp index cb643a741d..0afd0e7581 100644 --- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp +++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp @@ -179,7 +179,11 @@ std::string ClTemplateWriter::write_code() code += macros; } - code += write_kernel_signature(_vtable.get_variable_list(_components.get_argument_tensors())); + auto arguments = _components.get_argument_tensors(); + std::sort(arguments.begin(), arguments.end(), [](const ITensorInfo *l, const ITensorInfo *r) { + return l->id() < r->id(); + }); + code += write_kernel_signature(_vtable.get_variable_list(arguments)); code += "\n{\n\n"; @@ -190,6 +194,7 @@ std::string ClTemplateWriter::write_code() for(const auto &component_code : component_codes) { code += component_code; + code += "\n"; } code += "}\n"; |