aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/cl/CLKernelWriter.cpp
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-08-22 11:11:23 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-08-24 09:48:58 +0000
commit34b6c3a08c3fd3f99cf675921a319b8678a98273 (patch)
treeb52068faf874063b79f0e4ddd7d587e785bb65bf /compute_kernel_writer/src/cl/CLKernelWriter.cpp
parent3a9ecdfdc76abd7f9acdab42a1f7e4c0188d6f48 (diff)
downloadComputeLibrary-34b6c3a08c3fd3f99cf675921a319b8678a98273.tar.gz
Add CKW binary and ternary statements
Resolves: COMPMID-6388 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: Ia0cd1486f368af54053066f489cac83b9de01789 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10182 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.cpp')
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.cpp119
1 files changed, 118 insertions, 1 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index 33d16da926..312162f498 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -166,7 +166,7 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con
}
}
-void CLKernelWriter::op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op)
+void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src)
{
const auto &dst_tile = to_cl_tile(dst);
const auto &src_tile = to_cl_tile(src);
@@ -198,6 +198,123 @@ void CLKernelWriter::op_unary(const TileOperand &dst, const TileOperand &src, Un
}
}
+void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second)
+{
+ const auto &dst_tile = to_cl_tile(dst);
+ const auto &lhs_tile = to_cl_tile(first);
+ const auto &rhs_tile = to_cl_tile(second);
+
+ const auto dst_w = dst_tile.info().width();
+ const auto dst_h = dst_tile.info().height();
+ const auto lhs_w = lhs_tile.info().width();
+ const auto rhs_w = rhs_tile.info().width();
+
+ const auto data_type = lhs_tile.info().data_type();
+
+ CKW_ASSERT_MSG(lhs_tile.info().data_type() == rhs_tile.info().data_type(), "LHS and RHS type must match.");
+
+ CKW_ASSERT_MSG(lhs_tile.info().height() == dst_h || lhs_tile.info().height() == 1, "LHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(rhs_tile.info().height() == dst_h || rhs_tile.info().height() == 1, "RHS tile height must match or source is broadcasting in y dimension.");
+
+ CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, "LHS tile width must match destination or LHS is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, "RHS tile width must match destination or RHS is broadcasting in x dimension.");
+
+ if(op == BinaryOp::MatMul_Nt_T)
+ {
+ CKW_ASSERT(is_data_type_float(data_type));
+
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ for(int32_t x = 0; x < dst_w; ++x)
+ {
+ for(int32_t k = 0; k < lhs_w; ++k)
+ {
+ append_code(
+ dst_tile.scalar(x, y).str, " = fma(",
+ lhs_tile.scalar(k, y).str, ", ",
+ rhs_tile.scalar(k, x).str, ", ",
+ dst_tile.scalar(x, y).str, ");\n");
+ }
+ }
+ }
+ }
+ else
+ {
+ const auto op_info = cl_get_binary_op(op, data_type);
+ const auto op_is_func = std::get<0>(op_info);
+ const auto &op_name = std::get<1>(op_info);
+
+ const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w);
+
+ const auto broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
+ const auto broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
+
+ const std::string lhs_prefix = broadcast_lhs_x ? "(" + data_type_str + ")" : "";
+ const std::string rhs_prefix = broadcast_rhs_x ? "(" + data_type_str + ")" : "";
+
+ const std::string op_prefix = op_is_func ? " = " + op_name + "(" : " = ";
+ const std::string op_separator = op_is_func ? ", " : " " + op_name + " ";
+ const std::string op_suffix = op_is_func ? ");\n" : ";\n";
+
+ // Broadcasting on y dimension is automatic (see CLTile::vector).
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ append_code(dst_tile.vector(y).str, op_prefix, lhs_prefix, lhs_tile.vector(y).str, op_separator, rhs_prefix, rhs_tile.vector(y).str, op_suffix);
+ }
+ }
+}
+
+void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third)
+{
+ const auto &dst_tile = to_cl_tile(dst);
+ const auto &first_tile = to_cl_tile(first);
+ const auto &second_tile = to_cl_tile(second);
+ const auto &third_tile = to_cl_tile(third);
+
+ const auto dst_w = dst_tile.info().width();
+ const auto dst_h = dst_tile.info().height();
+ const auto first_w = first_tile.info().width();
+ const auto second_w = second_tile.info().width();
+ const auto third_w = third_tile.info().width();
+
+ const auto data_type = dst_tile.info().data_type();
+ const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w);
+
+ const auto op_info = cl_get_ternary_op(op);
+ const auto op_is_func = std::get<0>(op_info);
+ const auto &op_name = std::get<1>(op_info);
+
+ const auto broadcast_first_x = dst_w != 1 && first_w == 1;
+ const auto broadcast_second_x = dst_w != 1 && second_w == 1;
+ const auto broadcast_third_x = dst_w != 1 && third_w == 1;
+
+ const std::string first_prefix = broadcast_first_x ? "(" + data_type_str + ")" : "";
+ const std::string second_prefix = broadcast_second_x ? "(" + data_type_str + ")" : "";
+ const std::string third_prefix = broadcast_third_x ? "(" + data_type_str + ")" : "";
+
+ CKW_ASSERT_MSG(op_is_func, "The only supported ternary operator is function.");
+ CKW_ASSERT_MSG(second_tile.info().data_type() == dst_tile.info().data_type(), "2nd source and destination type must match.");
+ CKW_ASSERT_MSG(third_tile.info().data_type() == dst_tile.info().data_type(), "3rd source and destination type must match.");
+
+ CKW_ASSERT_MSG(first_tile.info().height() == dst_h || first_tile.info().height() == 1, "1st tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(second_tile.info().height() == dst_h || second_tile.info().height() == 1, "2nd tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(third_tile.info().height() == dst_h || third_tile.info().height() == 1, "3rd tile height must match or source is broadcasting in y dimension.");
+
+ CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, "1st tile width must match or source is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, "2nd tile width must match or source is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(third_w == dst_w || third_w == 1, "3rd tile width must match or source is broadcasting in x dimension.");
+
+ // Broadcasting on y dimension is automatic (see CLTile::vector).
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ append_code(
+ dst_tile.vector(y).str, " = ", op_name, "(",
+ first_prefix, first_tile.vector(y).str, ", ",
+ second_prefix, second_tile.vector(y).str, ", ",
+ third_prefix, third_tile.vector(y).str, ");\n");
+ }
+}
+
void CLKernelWriter::op_comment(const std::string &text)
{
#ifdef COMPUTE_KERNEL_WRITER_DEBUG_ENABLED