aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp')
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp37
1 files changed, 25 insertions, 12 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
index 6ab3a68bb0..dcb43f9783 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
@@ -54,20 +54,26 @@ std::string ClTemplateCast::get_component_code(const ComponentGroup &comp_group)
ARM_COMPUTE_UNUSED(comp_group);
const std::string kernel_name = get_name();
+ const auto is_root = (comp_group.get_root_component()->id() == this->id());
std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
+//------------------ START KERNEL {{meta_kernel_id}} CAST ---------------------
+)_";
+
+ if(is_root)
+ {
+ code += R"_(
// IN_0(src) {{src}}
// OUT(dst, accum) {{dst}}
-TILE({{DATA_TYPE_OUT}}, M0, N0, {{dst}});
TILE(uint, M0, 1, g_dst_indirect_y);
{
{{src}}_offset_first_element_in_bytes += get_global_id(2) * {{src}}_stride_z;
- TILE({{DATA_TYPE_IN}}, M0, N0, in_data);
- T_LOAD({{DATA_TYPE_IN}}, M0, N0, BUFFER, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, in_data);
+ TILE({{DATA_TYPE_IN}}, M0, N0, {{tmp}});
+ T_LOAD({{DATA_TYPE_IN}}, M0, N0, BUFFER, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, {{tmp}});
)_";
+ }
code += R"_(
LOOP_UNROLLING(int, m0, 0, 1, M0,
@@ -77,20 +83,20 @@ TILE(uint, M0, 1, g_dst_indirect_y);
if(kernel_name == "cast_down" && is_data_type_quantized(_src->data_type()))
{
code += R"_(
- in_data[m0].v ^= (VEC_DATA_TYPE({{DATA_TYPE_IN}}, N0))0x80;
+ {{tmp}}[m0].v ^= (VEC_DATA_TYPE({{DATA_TYPE_IN}}, N0))0x80;
)_";
}
if(kernel_name == "cast_down" && (is_data_type_float(_src->data_type()) || _attributes.convert_policy() == ConvertPolicy::SATURATE))
{
code += R"_(
- {{dst}}[m0].v = CONVERT_SAT(in_data[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
+ {{dst}}[m0].v = CONVERT_SAT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
)_";
}
else
{
code += R"_(
- {{dst}}[m0].v = CONVERT(in_data[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
+ {{dst}}[m0].v = CONVERT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
)_";
}
@@ -98,7 +104,9 @@ TILE(uint, M0, 1, g_dst_indirect_y);
})
)_";
- code += R"_(
+ if(is_root)
+ {
+ code += R"_(
LOOP_UNROLLING(int, i, 0, 1, M0,
{
g_dst_indirect_y[i].v = (uint)min((int)(g_ind_1 + i), (int)({{arg_dst}}_w) - 1);
@@ -106,7 +114,11 @@ TILE(uint, M0, 1, g_dst_indirect_y);
g_dst_indirect_y[i].v += (int)(g_ind_2 / {{arg_dst}}_h) * (int)({{arg_dst}}_w * {{arg_dst}}_h);
})
}
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
+)_";
+ }
+
+ code += R"_(
+//------------------ END KERNEL {{meta_kernel_id}} CAST ---------------------
)_";
return code;
@@ -115,27 +127,28 @@ TILE(uint, M0, 1, g_dst_indirect_y);
void ClTemplateCast::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
{
vtable.declare_variable(
+ comp_group,
_src,
GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- comp_group.is_intermediate_tensor(_src),
"src");
vtable.declare_variable(
+ comp_group,
_dst,
GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- comp_group.is_intermediate_tensor(_dst),
"dst");
}
TagLUT ClTemplateCast::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
{
- ARM_COMPUTE_UNUSED(comp_group);
+ const auto is_root = (comp_group.get_root_component()->id() == this->id());
TagLUT lut{};
// Arguments and global shared variables
lut["src"] = vtable.get_variable(_src);
lut["dst"] = vtable.get_variable(_dst);
+ lut["tmp"] = (is_root) ? lut["src"].value + "_in_data" : lut["src"];
const auto dst_argument = vtable.get_variable(comp_group.get_any_dst_tensor());
lut["arg_dst"] = dst_argument.uniq_name;