about | summary | refs | log | tree | commit | diff
path: root/compute_kernel_writer/src/cl/CLKernelWriter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.cpp')
-rw-r--r-- compute_kernel_writer/src/cl/CLKernelWriter.cpp 88
1 files changed, 60 insertions, 28 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index 62e6853a7a..8b4876b6a7 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -47,6 +47,25 @@
#include <tuple>
#include <vector>
+namespace
+{
+/** Build the OpenCL C preamble: conditionally enable cl_khr_fp16 / cl_arm_printf and define `inf` as INFINITY. */
+std::string generate_cl_extensions()
+{
+    return R"(
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif // defined(cl_khr_fp16)
+
+#if defined(cl_arm_printf)
+#pragma OPENCL EXTENSION cl_arm_printf : enable
+#endif // defined(cl_arm_printf);
+
+#define inf (INFINITY)
+)";
+}
+} // namespace
+
namespace ckw
{
@@ -56,7 +75,7 @@ CLKernelWriter::~CLKernelWriter() = default;
std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name)
{
std::string code;
-
+ code += generate_cl_extensions();
code += "__kernel void ";
code += name;
code += "\n(\n";
@@ -154,21 +173,31 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con
const auto dst_type_str = cl_get_variable_datatype_as_string(dst_type, dst_w);
const std::string sat = policy == ConvertPolicy::Saturate ? "_sat" : "";
+
CKW_ASSERT_IF(policy == ConvertPolicy::Saturate, !is_data_type_float(dst_type));
const auto broadcast_x = dst_w != 1 && src_w == 1;
const std::string prefix = broadcast_x ? "(" + dst_type_str + ")" : "";
- CKW_ASSERT_MSG(src_view.data_type() != dst_view.data_type(), "Source and destination type must be different.");
CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1,
"Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
- for (int32_t y = 0; y < dst_h; ++y)
+ if (src_view.data_type() == dst_view.data_type())
+ {
+ for (int32_t y = 0; y < dst_h; ++y)
+ {
+ append_code(dst_view.vector(y).str, " = ", src_view.vector(y).str, ";\n");
+ }
+ }
+ else
{
- append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(",
- src_view.vector(y).str, ");\n");
+ for (int32_t y = 0; y < dst_h; ++y)
+ {
+ append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(",
+ src_view.vector(y).str, ");\n");
+ }
}
}
@@ -219,18 +248,12 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp
CKW_ASSERT_MSG(lhs_view.data_type() == rhs_view.data_type(), "LHS and RHS type must match.");
- CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1,
- "LHS tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1,
- "RHS tile height must match or source is broadcasting in y dimension.");
-
- CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1,
- "LHS tile width must match destination or LHS is broadcasting in x dimension.");
- CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1,
- "RHS tile width must match destination or RHS is broadcasting in x dimension.");
-
if (op == BinaryOp::MatMul_Nt_T)
{
+ CKW_ASSERT_MSG(lhs_view.height() == dst_h, "LHS tile height must match the DST tile height");
+ CKW_ASSERT_MSG(rhs_view.height() == dst_w, "RHS tile height must match the DST tile width");
+ CKW_ASSERT_MSG(lhs_view.width() == rhs_view.width(), "LHS tile width must match the RHS tile width");
+
CKW_ASSERT(is_data_type_float(data_type));
for (int32_t y = 0; y < dst_h; ++y)
@@ -239,14 +262,24 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp
{
for (int32_t k = 0; k < lhs_w; ++k)
{
- append_code(dst_view.scalar(x, y).str, " = fma(", lhs_view.scalar(k, y).str, ", ",
- rhs_view.scalar(k, x).str, ", ", dst_view.scalar(x, y).str, ");\n");
+ append_code(dst_view.scalar(y, x).str, " = fma(", lhs_view.scalar(y, k).str, ", ",
+ rhs_view.scalar(x, k).str, ", ", dst_view.scalar(y, x).str, ");\n");
}
}
}
}
else
{
+ CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1,
+ "LHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1,
+ "RHS tile height must match or source is broadcasting in y dimension.");
+
+ CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1,
+ "LHS tile width must match destination or LHS is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1,
+ "RHS tile width must match destination or RHS is broadcasting in x dimension.");
+
const auto op_info = cl_get_binary_op(op, data_type);
const auto op_is_func = std::get<0>(op_info);
const auto &op_name = std::get<1>(op_info);
@@ -746,36 +779,35 @@ void CLKernelWriter::op_load_store(MemoryOperation op,
ITensor &tensor = get_tensor(tensor_op);
+ const auto tile = to_cl_tile_view(tile_op);
+ const auto x_tile = to_cl_tile_view(x).full_tile();
+ const auto y_tile = to_cl_tile_view(y).full_tile();
+ const auto z_tile = to_cl_tile_view(z).full_tile();
+ const auto batch_tile = to_cl_tile_view(batch).full_tile();
+
std::unique_ptr<ICLMemoryOpHelper> helper;
switch (sampler.storage())
{
case TensorStorageType::BufferUint8Ptr:
- helper = std::make_unique<CLMemoryOpBufferHelper>(this, &tensor, &sampler, op);
+ helper = std::make_unique<CLMemoryOpBufferHelper>(this, &tensor, &sampler, op, tile);
break;
case TensorStorageType::Texture2dReadOnly:
case TensorStorageType::Texture2dWriteOnly:
- helper = std::make_unique<CLMemoryOpImage2dHelper>(this, &tensor, &sampler, op);
+ helper = std::make_unique<CLMemoryOpImage2dHelper>(this, &tensor, &sampler, op, tile);
break;
default:
CKW_THROW_MSG("Unsupported tensor storage");
}
- // Load/store op doesn't support sub-tile access.
- const auto tile = to_cl_tile_view(tile_op).full_tile();
- const auto x_tile = to_cl_tile_view(x).full_tile();
- const auto y_tile = to_cl_tile_view(y).full_tile();
- const auto z_tile = to_cl_tile_view(z).full_tile();
- const auto batch_tile = to_cl_tile_view(batch).full_tile();
-
CKW_ASSERT(x_tile.is_scalar());
CKW_ASSERT(z_tile.is_scalar());
CKW_ASSERT_IF(indirect_buffer, y_tile.info().width() == 1);
CKW_ASSERT_IF(!indirect_buffer, y_tile.is_scalar());
CKW_ASSERT(batch_tile.is_scalar());
- helper->initialize(&tile, &x_tile, &z_tile, &batch_tile);
+ helper->initialize(&x_tile, &z_tile, &batch_tile);
- for (int row = 0; row < tile.info().height(); ++row)
+ for (int row = 0; row < tile.height(); ++row)
{
if (!indirect_buffer)
{