diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp | 30 |
1 files changed, 13 insertions, 17 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp index 4956879ad3..0da3a73801 100644 --- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp +++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp @@ -25,6 +25,7 @@ #include "arm_compute/core/utils/helpers/AdjustVecSize.h" #include "arm_compute/core/utils/StringUtils.h" + #include "src/core/helpers/WindowHelpers.h" #include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h" @@ -35,7 +36,7 @@ namespace experimental namespace dynamic_fusion { ClTemplateCast::ClTemplateCast(ComponentId id, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes) - : IGpuTemplateComponentWriter{ id, tensors }, _src{}, _dst{}, _attributes{ attributes } + : IGpuTemplateComponentWriter{id, tensors}, _src{}, _dst{}, _attributes{attributes} { _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0); _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0); @@ -62,7 +63,7 @@ std::string ClTemplateCast::get_component_code(const ComponentGroup &comp_group) //------------------ START KERNEL {{meta_kernel_id}} CAST --------------------- )_"; - if(is_root) + if (is_root) { code += R"_( // IN_0(src) {{src}} @@ -82,14 +83,15 @@ TILE(uint, M0, 1, g_dst_indirect_y); { )_"; - if(kernel_name == "cast_down" && is_data_type_quantized(_src->data_type())) + if (kernel_name == "cast_down" && is_data_type_quantized(_src->data_type())) { code += R"_( {{tmp}}[m0].v ^= (VEC_DATA_TYPE({{DATA_TYPE_IN}}, N0))0x80; )_"; } - if(kernel_name == "cast_down" && (is_data_type_float(_src->data_type()) || _attributes.convert_policy() == ConvertPolicy::SATURATE)) + if (kernel_name == "cast_down" && + (is_data_type_float(_src->data_type()) || _attributes.convert_policy() == ConvertPolicy::SATURATE)) { code += R"_( {{dst}}[m0].v = CONVERT_SAT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0)); @@ -106,7 +108,7 @@ TILE(uint, M0, 1, g_dst_indirect_y); }) )_"; - if(is_root) + if (is_root) { code += R"_( LOOP_UNROLLING(int, i, 0, 1, M0, @@ -128,17 +130,11 @@ TILE(uint, M0, 1, g_dst_indirect_y); void ClTemplateCast::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const { - vtable.declare_variable( - comp_group, - _src, - GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer), - "src"); - - vtable.declare_variable( - comp_group, - _dst, - GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer), - "dst"); + vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer), + "src"); + + vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer), + "dst"); } TagLUT ClTemplateCast::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const @@ -199,7 +195,7 @@ std::string ClTemplateCast::get_config_id() const std::set<std::string> ClTemplateCast::get_headers_list() const { - return std::set<std::string>{ "helpers.h", "tile_helpers.h" }; + return std::set<std::string>{"helpers.h", "tile_helpers.h"}; } Window ClTemplateCast::get_window() const |