From 34b6c3a08c3fd3f99cf675921a319b8678a98273 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 22 Aug 2023 11:11:23 +0100 Subject: Add CKW binary and ternary statements Resolves: COMPMID-6388 Signed-off-by: Viet-Hoa Do Change-Id: Ia0cd1486f368af54053066f489cac83b9de01789 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10182 Reviewed-by: Gunes Bayir Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Benchmark: Arm Jenkins --- compute_kernel_writer/src/cl/CLHelpers.cpp | 71 ++++++++++++++ compute_kernel_writer/src/cl/CLHelpers.h | 27 ++++++ compute_kernel_writer/src/cl/CLKernelWriter.cpp | 119 +++++++++++++++++++++++- compute_kernel_writer/src/cl/CLKernelWriter.h | 6 +- 4 files changed, 221 insertions(+), 2 deletions(-) (limited to 'compute_kernel_writer/src/cl') diff --git a/compute_kernel_writer/src/cl/CLHelpers.cpp b/compute_kernel_writer/src/cl/CLHelpers.cpp index f62e1c28e6..e12e5e1b13 100644 --- a/compute_kernel_writer/src/cl/CLHelpers.cpp +++ b/compute_kernel_writer/src/cl/CLHelpers.cpp @@ -181,6 +181,77 @@ std::tuple cl_get_unary_op(UnaryOp op) } } +std::tuple cl_get_binary_op(BinaryOp op, DataType data_type) +{ + const auto is_float = is_data_type_float(data_type); + + switch(op) + { + case BinaryOp::Add: + return { false, "+" }; + + case BinaryOp::Sub: + return { false, "-" }; + + case BinaryOp::Mul: + return { false, "*" }; + + case BinaryOp::Div: + return { false, "/" }; + + case BinaryOp::Mod: + return { false, "%" }; + + case BinaryOp::Equal: + return { false, "==" }; + + case BinaryOp::Less: + return { false, "<" }; + + case BinaryOp::LessEqual: + return { false, "<=" }; + + case BinaryOp::Greater: + return { false, ">" }; + + case BinaryOp::GreaterEqual: + return { false, ">=" }; + + case BinaryOp::LogicalAnd: + return { false, "&&" }; + + case BinaryOp::LogicalOr: + return { false, "||" }; + + case BinaryOp::BitwiseXOR: + return { false, "^" }; + + case BinaryOp::Min: + return { true, is_float ? 
"fmin" : "min" }; + + case BinaryOp::Max: + return { true, is_float ? "fmax" : "max" }; + + default: + CKW_THROW_MSG("Unsupported binary operator/function!"); + } +} + +std::tuple cl_get_ternary_op(TernaryOp op) +{ + switch(op) + { + case TernaryOp::Select: + return { true, "select" }; + + case TernaryOp::Clamp: + return { true, "clamp" }; + + default: + CKW_THROW_MSG("Unsupported ternary function!"); + } +} + std::string cl_data_type_rounded_up_to_valid_vector_width(DataType dt, int32_t width) { std::string data_type; diff --git a/compute_kernel_writer/src/cl/CLHelpers.h b/compute_kernel_writer/src/cl/CLHelpers.h index 3c1a7724e2..370ffc700c 100644 --- a/compute_kernel_writer/src/cl/CLHelpers.h +++ b/compute_kernel_writer/src/cl/CLHelpers.h @@ -70,9 +70,36 @@ std::string cl_get_assignment_op_as_string(AssignmentOp op); * - str: the function name or the operator in OpenCL language. * * @param[in] op The unary operator. + * + * @return The information about the unary operation. */ std::tuple cl_get_unary_op(UnaryOp op); +/** Return the information about the binary operation. + * + * The result contains: + * - is_func: true if it's a function and false if it's a binary operator in OpenCL language. + * - str: the function name or the operator in OpenCL language. + * + * @param[in] op The binary operator. + * @param[in] data_type The input data type. + * + * @return The information about the binary operation. + */ +std::tuple cl_get_binary_op(BinaryOp op, DataType data_type); + +/** Return the information about the ternary operation. + * + * The result contains: + * - is_func: true if it's a function and false if it's a ternary operator in OpenCL language. + * - str: the function name or the operator in OpenCL language. + * + * @param[in] op The ternary operator. + * + * @return The information about the ternary operation.
+ */ +std::tuple cl_get_ternary_op(TernaryOp op); + /** Helper function to return the OpenCL vector size that accommodate the the desired width * * @param[in] width The desired width diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp index 33d16da926..312162f498 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp +++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp @@ -166,7 +166,7 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con } } -void CLKernelWriter::op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op) +void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) { const auto &dst_tile = to_cl_tile(dst); const auto &src_tile = to_cl_tile(src); @@ -198,6 +198,123 @@ void CLKernelWriter::op_unary(const TileOperand &dst, const TileOperand &src, Un } } +void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) +{ + const auto &dst_tile = to_cl_tile(dst); + const auto &lhs_tile = to_cl_tile(first); + const auto &rhs_tile = to_cl_tile(second); + + const auto dst_w = dst_tile.info().width(); + const auto dst_h = dst_tile.info().height(); + const auto lhs_w = lhs_tile.info().width(); + const auto rhs_w = rhs_tile.info().width(); + + const auto data_type = lhs_tile.info().data_type(); + + CKW_ASSERT_MSG(lhs_tile.info().data_type() == rhs_tile.info().data_type(), "LHS and RHS type must match."); + + CKW_ASSERT_MSG(lhs_tile.info().height() == dst_h || lhs_tile.info().height() == 1, "LHS tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(rhs_tile.info().height() == dst_h || rhs_tile.info().height() == 1, "RHS tile height must match or source is broadcasting in y dimension."); + + CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, "LHS tile width must match destination or LHS is broadcasting in x dimension."); + CKW_ASSERT_MSG(rhs_w == 
dst_w || rhs_w == 1, "RHS tile width must match destination or RHS is broadcasting in x dimension."); + + if(op == BinaryOp::MatMul_Nt_T) + { + CKW_ASSERT(is_data_type_float(data_type)); + + for(int32_t y = 0; y < dst_h; ++y) + { + for(int32_t x = 0; x < dst_w; ++x) + { + for(int32_t k = 0; k < lhs_w; ++k) + { + append_code( + dst_tile.scalar(x, y).str, " = fma(", + lhs_tile.scalar(k, y).str, ", ", + rhs_tile.scalar(k, x).str, ", ", + dst_tile.scalar(x, y).str, ");\n"); + } + } + } + } + else + { + const auto op_info = cl_get_binary_op(op, data_type); + const auto op_is_func = std::get<0>(op_info); + const auto &op_name = std::get<1>(op_info); + + const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w); + + const auto broadcast_lhs_x = dst_w != 1 && lhs_w == 1; + const auto broadcast_rhs_x = dst_w != 1 && rhs_w == 1; + + const std::string lhs_prefix = broadcast_lhs_x ? "(" + data_type_str + ")" : ""; + const std::string rhs_prefix = broadcast_rhs_x ? "(" + data_type_str + ")" : ""; + + const std::string op_prefix = op_is_func ? " = " + op_name + "(" : " = "; + const std::string op_separator = op_is_func ? ", " : " " + op_name + " "; + const std::string op_suffix = op_is_func ? ");\n" : ";\n"; + + // Broadcasting on y dimension is automatic (see CLTile::vector). 
+ for(int32_t y = 0; y < dst_h; ++y) + { + append_code(dst_tile.vector(y).str, op_prefix, lhs_prefix, lhs_tile.vector(y).str, op_separator, rhs_prefix, rhs_tile.vector(y).str, op_suffix); + } + } +} + +void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) +{ + const auto &dst_tile = to_cl_tile(dst); + const auto &first_tile = to_cl_tile(first); + const auto &second_tile = to_cl_tile(second); + const auto &third_tile = to_cl_tile(third); + + const auto dst_w = dst_tile.info().width(); + const auto dst_h = dst_tile.info().height(); + const auto first_w = first_tile.info().width(); + const auto second_w = second_tile.info().width(); + const auto third_w = third_tile.info().width(); + + const auto data_type = dst_tile.info().data_type(); + const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w); + + const auto op_info = cl_get_ternary_op(op); + const auto op_is_func = std::get<0>(op_info); + const auto &op_name = std::get<1>(op_info); + + const auto broadcast_first_x = dst_w != 1 && first_w == 1; + const auto broadcast_second_x = dst_w != 1 && second_w == 1; + const auto broadcast_third_x = dst_w != 1 && third_w == 1; + + const std::string first_prefix = broadcast_first_x ? "(" + data_type_str + ")" : ""; + const std::string second_prefix = broadcast_second_x ? "(" + data_type_str + ")" : ""; + const std::string third_prefix = broadcast_third_x ? 
"(" + data_type_str + ")" : ""; + + CKW_ASSERT_MSG(op_is_func, "The only supported ternary operator is function."); + CKW_ASSERT_MSG(second_tile.info().data_type() == dst_tile.info().data_type(), "2nd source and destination type must match."); + CKW_ASSERT_MSG(third_tile.info().data_type() == dst_tile.info().data_type(), "3rd source and destination type must match."); + + CKW_ASSERT_MSG(first_tile.info().height() == dst_h || first_tile.info().height() == 1, "1st tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(second_tile.info().height() == dst_h || second_tile.info().height() == 1, "2nd tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(third_tile.info().height() == dst_h || third_tile.info().height() == 1, "3rd tile height must match or source is broadcasting in y dimension."); + + CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, "1st tile width must match or source is broadcasting in x dimension."); + CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, "2nd tile width must match or source is broadcasting in x dimension."); + CKW_ASSERT_MSG(third_w == dst_w || third_w == 1, "3rd tile width must match or source is broadcasting in x dimension."); + + // Broadcasting on y dimension is automatic (see CLTile::vector). 
+ for(int32_t y = 0; y < dst_h; ++y) + { + append_code( + dst_tile.vector(y).str, " = ", op_name, "(", + first_prefix, first_tile.vector(y).str, ", ", + second_prefix, second_tile.vector(y).str, ", ", + third_prefix, third_tile.vector(y).str, ");\n"); + } +} + void CLKernelWriter::op_comment(const std::string &text) { #ifdef COMPUTE_KERNEL_WRITER_DEBUG_ENABLED diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h index ea455a7fdd..2a6b79c691 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.h +++ b/compute_kernel_writer/src/cl/CLKernelWriter.h @@ -64,7 +64,11 @@ public: void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) override; - void op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op) override; + void op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) override; + + void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) override; + + void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) override; // ============================================================================================= // Misc -- cgit v1.2.1