diff options
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.cpp')
-rw-r--r-- | compute_kernel_writer/src/cl/CLKernelWriter.cpp | 88 |
1 files changed, 60 insertions, 28 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp index 62e6853a7a..8b4876b6a7 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp +++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp @@ -47,6 +47,25 @@ #include <tuple> #include <vector> +namespace +{ +std::string generate_cl_extensions() +{ + std::string ext = R"( +#if defined(cl_khr_fp16) +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif // defined(cl_khr_fp16) + +#if defined(cl_arm_printf) +#pragma OPENCL EXTENSION cl_arm_printf : enable +#endif // defined(cl_arm_printf); + +#define inf (INFINITY) +)"; + return ext; +} +} // namespace + namespace ckw { @@ -56,7 +75,7 @@ CLKernelWriter::~CLKernelWriter() = default; std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name) { std::string code; - + code += generate_cl_extensions(); code += "__kernel void "; code += name; code += "\n(\n"; @@ -154,21 +173,31 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con const auto dst_type_str = cl_get_variable_datatype_as_string(dst_type, dst_w); const std::string sat = policy == ConvertPolicy::Saturate ? "_sat" : ""; + CKW_ASSERT_IF(policy == ConvertPolicy::Saturate, !is_data_type_float(dst_type)); const auto broadcast_x = dst_w != 1 && src_w == 1; const std::string prefix = broadcast_x ? "(" + dst_type_str + ")" : ""; - CKW_ASSERT_MSG(src_view.data_type() != dst_view.data_type(), "Source and destination type must be different."); CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension."); // Broadcasting on y dimension is automatic (see CLTile::vector). - for (int32_t y = 0; y < dst_h; ++y) + if (src_view.data_type() == dst_view.data_type()) + { + for (int32_t y = 0; y < dst_h; ++y) + { + append_code(dst_view.vector(y).str, " = ", src_view.vector(y).str, ";\n"); + } + } + else { - append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", - src_view.vector(y).str, ");\n"); + for (int32_t y = 0; y < dst_h; ++y) + { + append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", + src_view.vector(y).str, ");\n"); + } } } @@ -219,18 +248,12 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp CKW_ASSERT_MSG(lhs_view.data_type() == rhs_view.data_type(), "LHS and RHS type must match."); - CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1, - "LHS tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1, - "RHS tile height must match or source is broadcasting in y dimension."); - - CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, - "LHS tile width must match destination or LHS is broadcasting in x dimension."); - CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, - "RHS tile width must match destination or RHS is broadcasting in x dimension."); - if (op == BinaryOp::MatMul_Nt_T) { + CKW_ASSERT_MSG(lhs_view.height() == dst_h, "LHS tile height must match the DST tile height"); + CKW_ASSERT_MSG(rhs_view.height() == dst_w, "RHS tile height must match the DST tile width"); + CKW_ASSERT_MSG(lhs_view.width() == rhs_view.width(), "LHS tile width must match the LHS tile width"); + CKW_ASSERT(is_data_type_float(data_type)); for (int32_t y = 0; y < dst_h; ++y) @@ -239,14 +262,24 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp { for (int32_t k = 0; k < lhs_w; ++k) { - append_code(dst_view.scalar(x, y).str, " = fma(", lhs_view.scalar(k, y).str, ", ", - rhs_view.scalar(k, x).str, ", ", dst_view.scalar(x, y).str, ");\n"); + append_code(dst_view.scalar(y, x).str, " = fma(", lhs_view.scalar(y, k).str, ", ", + rhs_view.scalar(x, k).str, ", ", dst_view.scalar(y, x).str, ");\n"); } } } } else { + CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1, + "LHS tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1, + "RHS tile height must match or source is broadcasting in y dimension."); + + CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, + "LHS tile width must match destination or LHS is broadcasting in x dimension."); + CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, + "RHS tile width must match destination or RHS is broadcasting in x dimension."); + const auto op_info = cl_get_binary_op(op, data_type); const auto op_is_func = std::get<0>(op_info); const auto &op_name = std::get<1>(op_info); @@ -746,36 +779,35 @@ void CLKernelWriter::op_load_store(MemoryOperation op, ITensor &tensor = get_tensor(tensor_op); + const auto tile = to_cl_tile_view(tile_op); + const auto x_tile = to_cl_tile_view(x).full_tile(); + const auto y_tile = to_cl_tile_view(y).full_tile(); + const auto z_tile = to_cl_tile_view(z).full_tile(); + const auto batch_tile = to_cl_tile_view(batch).full_tile(); + std::unique_ptr<ICLMemoryOpHelper> helper; switch (sampler.storage()) { case TensorStorageType::BufferUint8Ptr: - helper = std::make_unique<CLMemoryOpBufferHelper>(this, &tensor, &sampler, op); + helper = std::make_unique<CLMemoryOpBufferHelper>(this, &tensor, &sampler, op, tile); break; case TensorStorageType::Texture2dReadOnly: case TensorStorageType::Texture2dWriteOnly: - helper = std::make_unique<CLMemoryOpImage2dHelper>(this, &tensor, &sampler, op); + helper = std::make_unique<CLMemoryOpImage2dHelper>(this, &tensor, &sampler, op, tile); break; default: CKW_THROW_MSG("Unsupported tensor storage"); } - // Load/store op doesn't support sub-tile access. - const auto tile = to_cl_tile_view(tile_op).full_tile(); - const auto x_tile = to_cl_tile_view(x).full_tile(); - const auto y_tile = to_cl_tile_view(y).full_tile(); - const auto z_tile = to_cl_tile_view(z).full_tile(); - const auto batch_tile = to_cl_tile_view(batch).full_tile(); - CKW_ASSERT(x_tile.is_scalar()); CKW_ASSERT(z_tile.is_scalar()); CKW_ASSERT_IF(indirect_buffer, y_tile.info().width() == 1); CKW_ASSERT_IF(!indirect_buffer, y_tile.is_scalar()); CKW_ASSERT(batch_tile.is_scalar()); - helper->initialize(&tile, &x_tile, &z_tile, &batch_tile); + helper->initialize(&x_tile, &z_tile, &batch_tile); - for (int row = 0; row < tile.info().height(); ++row) + for (int row = 0; row < tile.height(); ++row) { if (!indirect_buffer) { |