diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp | 37 |
1 files changed, 25 insertions, 12 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp index 6ab3a68bb0..dcb43f9783 100644 --- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp +++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp @@ -54,20 +54,26 @@ std::string ClTemplateCast::get_component_code(const ComponentGroup &comp_group) ARM_COMPUTE_UNUSED(comp_group); const std::string kernel_name = get_name(); + const auto is_root = (comp_group.get_root_component()->id() == this->id()); std::string code = R"_( -//------------------ START KERNEL {{meta_kernel_id}} --------------------- +//------------------ START KERNEL {{meta_kernel_id}} CAST --------------------- +)_"; + + if(is_root) + { + code += R"_( // IN_0(src) {{src}} // OUT(dst, accum) {{dst}} -TILE({{DATA_TYPE_OUT}}, M0, N0, {{dst}}); TILE(uint, M0, 1, g_dst_indirect_y); { {{src}}_offset_first_element_in_bytes += get_global_id(2) * {{src}}_stride_z; - TILE({{DATA_TYPE_IN}}, M0, N0, in_data); - T_LOAD({{DATA_TYPE_IN}}, M0, N0, BUFFER, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, in_data); + TILE({{DATA_TYPE_IN}}, M0, N0, {{tmp}}); + T_LOAD({{DATA_TYPE_IN}}, M0, N0, BUFFER, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, {{tmp}}); )_"; + } code += R"_( LOOP_UNROLLING(int, m0, 0, 1, M0, @@ -77,20 +83,20 @@ TILE(uint, M0, 1, g_dst_indirect_y); if(kernel_name == "cast_down" && is_data_type_quantized(_src->data_type())) { code += R"_( - in_data[m0].v ^= (VEC_DATA_TYPE({{DATA_TYPE_IN}}, N0))0x80; + {{tmp}}[m0].v ^= (VEC_DATA_TYPE({{DATA_TYPE_IN}}, N0))0x80; )_"; } if(kernel_name == "cast_down" && (is_data_type_float(_src->data_type()) || _attributes.convert_policy() == ConvertPolicy::SATURATE)) { code += R"_( - {{dst}}[m0].v = CONVERT_SAT(in_data[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0)); + {{dst}}[m0].v = CONVERT_SAT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0)); )_"; } else { code += R"_( - {{dst}}[m0].v = CONVERT(in_data[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0)); + {{dst}}[m0].v = CONVERT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0)); )_"; } @@ -98,7 +104,9 @@ TILE(uint, M0, 1, g_dst_indirect_y); }) )_"; - code += R"_( + if(is_root) + { + code += R"_( LOOP_UNROLLING(int, i, 0, 1, M0, { g_dst_indirect_y[i].v = (uint)min((int)(g_ind_1 + i), (int)({{arg_dst}}_w) - 1); @@ -106,7 +114,11 @@ TILE(uint, M0, 1, g_dst_indirect_y); g_dst_indirect_y[i].v += (int)(g_ind_2 / {{arg_dst}}_h) * (int)({{arg_dst}}_w * {{arg_dst}}_h); }) } -//------------------ END KERNEL {{meta_kernel_id}} --------------------- +)_"; + } + + code += R"_( +//------------------ END KERNEL {{meta_kernel_id}} CAST --------------------- )_"; return code; @@ -115,27 +127,28 @@ TILE(uint, M0, 1, g_dst_indirect_y); void ClTemplateCast::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const { vtable.declare_variable( + comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer), - comp_group.is_intermediate_tensor(_src), "src"); vtable.declare_variable( + comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer), - comp_group.is_intermediate_tensor(_dst), "dst"); } TagLUT ClTemplateCast::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const { - ARM_COMPUTE_UNUSED(comp_group); + const auto is_root = (comp_group.get_root_component()->id() == this->id()); TagLUT lut{}; // Arguments and global shared variables lut["src"] = vtable.get_variable(_src); lut["dst"] = vtable.get_variable(_dst); + lut["tmp"] = (is_root) ? lut["src"].value + "_in_data" : lut["src"]; const auto dst_argument = vtable.get_variable(comp_group.get_any_dst_tensor()); lut["arg_dst"] = dst_argument.uniq_name; |