diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp | 122 |
1 files changed, 35 insertions, 87 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp index df8deee44f..01017ed909 100644 --- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp +++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Arm Limited. + * Copyright (c) 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -61,9 +61,7 @@ std::string ClTemplateElementwiseBinary::get_name() const std::string ClTemplateElementwiseBinary::get_component_code(const ComponentGroup &comp_group) const { - ARM_COMPUTE_UNUSED(comp_group); std::string code; - const bool is_broadcast = _lhs->tensor_shape() != _rhs->tensor_shape(); const bool is_root = (comp_group.get_root_component()->id() == this->id()); const bool is_lhs_input = comp_group.is_input_tensor(_lhs); const bool is_rhs_input = comp_group.is_input_tensor(_rhs); @@ -85,7 +83,7 @@ R"_( { code += R"_( - TILE({{DATA_TYPE}}, M0, N0, {{lhs}}); + TILE({{DATA_TYPE}}, {{lhs_m0}}, N0, {{lhs}}); )_"; } @@ -93,7 +91,7 @@ R"_( { code += R"_( - TILE({{DATA_TYPE}}, M0, N0, {{rhs}}); + TILE({{DATA_TYPE}}, {{rhs_m0}}, N0, {{rhs}}); )_"; } @@ -106,7 +104,7 @@ R"_( { code += R"_( - {{lhs}}_offset_first_element_in_bytes += g_ind_2 * {{lhs}}_stride_z; + {{lhs}}_offset_first_element_in_bytes += g_ind_2 * {{lhs}}_stride_w; T_LOAD({{DATA_TYPE}}, {{lhs_m0}}, {{lhs_n0}}, BUFFER, {{lhs}}, {{lhs_start_ind_0}}, {{lhs_start_ind_1}}, 1, {{lhs}}_stride_y, {{lhs}}); )_"; } @@ -115,25 +113,15 @@ R"_( { code += R"_( - {{rhs}}_offset_first_element_in_bytes += g_ind_2 * {{rhs}}_stride_z; + {{rhs}}_offset_first_element_in_bytes += g_ind_2 * {{rhs}}_stride_w; T_LOAD({{DATA_TYPE}}, {{rhs_m0}}, {{rhs_n0}}, BUFFER, {{rhs}}, {{rhs_start_ind_0}}, {{rhs_start_ind_1}}, 1, {{rhs}}_stride_y, {{rhs}}); )_"; } - if(is_broadcast) - { - code += - R"_( - T_ELTWISE_BROADCAST_{{ELTWISE_OP}}_X({{DATA_TYPE}}, M0, N0, {{lhs}}, {{rhs}}, {{dst}}); -)_"; - } - else - { - code += - R"_( - T_ELTWISE_{{ELTWISE_OP}}({{DATA_TYPE}}, M0, N0, {{lhs}}, {{rhs}}, {{dst}}); + code += +R"_( + T_ELTWISE_{{BROADCAST_OP}}{{ELTWISE_OP}}({{DATA_TYPE}}, M0, N0, {{lhs}}, {{rhs}}, {{dst}}); )_"; - } if(is_root) { @@ -210,73 +198,33 @@ TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vt // Set broadcast parameters // PRE: All tensors are broadcast-compatible - if(_lhs->tensor_shape() != _dst->tensor_shape()) - { - const auto is_broadcast_x = _lhs->dimension(0) == 1U && _dst->dimension(0) != 1U; - const auto is_broadcast_y = _lhs->dimension(1) == 1U && _dst->dimension(1) != 1U; - const auto is_broadcast_z = _lhs->dimension(2) == 1U && _dst->dimension(2) != 1U; - - // Note that n0 maps to input tensor dimension 0, m0 maps to input dimensions 1 and 2 because of our collapse strategy - if(is_broadcast_x && is_broadcast_y && is_broadcast_z) // Broadcast in X, Y, Z: collapsed lhs win [M0xN0] = [1x1] - { - lut["lhs_m0"] = "1"; - lut["lhs_n0"] = "1"; - lut["lhs_start_ind_1"] = "0"; - lut["lhs_start_ind_0"] = "0"; - } - else if(is_broadcast_y && is_broadcast_z) // Broadcast in Y and Z: collapsed lhs win [M0xN0] = [1xN] - { - lut["lhs_m0"] = "1"; - lut["lhs_n0"] = "N0"; - lut["lhs_start_ind_1"] = "0"; - lut["lhs_start_ind_0"] = "g_ind_0"; - } - else - { - ARM_COMPUTE_ERROR("Only support lhs broadcasting in all X, Y, Z dimensions, or just in Y and Z dimensions"); - } - } - else - { - lut["lhs_m0"] = "M0"; - lut["lhs_n0"] = "N0"; - lut["lhs_start_ind_1"] = "g_ind_1"; - lut["lhs_start_ind_0"] = "g_ind_0"; - } - - if(_rhs->tensor_shape() != _dst->tensor_shape()) - { - const auto is_broadcast_x = _rhs->dimension(0) == 1U && _dst->dimension(0) != 1U; - const auto is_broadcast_y = _rhs->dimension(1) == 1U && _dst->dimension(1) != 1U; - const auto is_broadcast_z = _rhs->dimension(2) == 1U && _dst->dimension(2) != 1U; - - // Note that n0 maps to input tensor dimension 0, m0 maps to input dimensions 1 and 2 because of our collapse strategy - if(is_broadcast_x && is_broadcast_y && is_broadcast_z) // Broadcast in X, Y, Z: collapsed rhs win [M0xN0] = [1x1] - { - lut["rhs_m0"] = "1"; - lut["rhs_n0"] = "1"; - lut["rhs_start_ind_1"] = "0"; - lut["rhs_start_ind_0"] = "0"; - } - else if(is_broadcast_y && is_broadcast_z) // Broadcast in Y and Z: collapsed rhs win [M0xN0] = [1xN] - { - lut["rhs_m0"] = "1"; - lut["rhs_n0"] = "N0"; - lut["rhs_start_ind_1"] = "0"; - lut["rhs_start_ind_0"] = "g_ind_0"; - } - else - { - ARM_COMPUTE_ERROR("Only support rhs broadcasting in all X, Y, Z dimensions, or just in Y and Z dimensions"); - } - } - else - { - lut["rhs_m0"] = "M0"; - lut["rhs_n0"] = "N0"; - lut["rhs_start_ind_1"] = "g_ind_1"; - lut["rhs_start_ind_0"] = "g_ind_0"; - } + const auto &lhs_dims = _lhs->tensor_shape(); + const auto &rhs_dims = _rhs->tensor_shape(); + const auto &dst_dims = _dst->tensor_shape(); + + const auto lhs_broadcast_x = dst_dims[0] != 1 && lhs_dims[0] == 1; + const auto rhs_broadcast_x = dst_dims[0] != 1 && rhs_dims[0] == 1; + const auto lhs_broadcast_y = dst_dims[1] != 1 && lhs_dims[1] == 1; + const auto rhs_broadcast_y = dst_dims[1] != 1 && rhs_dims[1] == 1; + const auto lhs_broadcast_z = dst_dims[2] != 1 && lhs_dims[2] == 1; + const auto rhs_broadcast_z = dst_dims[2] != 1 && rhs_dims[2] == 1; + + const auto lhs_broadcast_yz = lhs_broadcast_y && lhs_broadcast_z; + const auto rhs_broadcast_yz = rhs_broadcast_y && rhs_broadcast_z; + + lut["lhs_n0"] = (lhs_broadcast_x) ? "1" : "N0"; + lut["lhs_start_ind_0"] = (lhs_broadcast_x) ? "0" : "g_ind_0"; + lut["rhs_n0"] = (rhs_broadcast_x) ? "1" : "N0"; + lut["rhs_start_ind_0"] = (rhs_broadcast_x) ? "0" : "g_ind_0"; + + lut["lhs_m0"] = (lhs_broadcast_yz) ? "1" : "M0"; + lut["lhs_start_ind_1"] = (lhs_broadcast_yz) ? "0" : "g_ind_1"; + lut["rhs_m0"] = (rhs_broadcast_yz) ? "1" : "M0"; + lut["rhs_start_ind_1"] = (rhs_broadcast_yz) ? "0" : "g_ind_1"; + + lut["BROADCAST_OP"] = (lhs_broadcast_yz) ? "BROADCAST_LHS_X_" : + (rhs_broadcast_yz) ? "BROADCAST_RHS_X_" : + ""; return lut; } |