aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/cl
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/cl')
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.cpp100
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.h4
2 files changed, 104 insertions, 0 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index 4284388c0b..a946b989d7 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -385,6 +385,106 @@ void CLKernelWriter::op_return()
append_code("return;\n");
}
+void CLKernelWriter::op_get_global_id(const TileOperand &dst, int32_t dim)
+{
+ const auto &tile = to_cl_tile(dst);
+
+ CKW_ASSERT(tile.is_scalar());
+ CKW_ASSERT(tile.info().data_type() == DataType::Int32 || tile.info().data_type() == DataType::Uint32);
+
+ CKW_ASSERT(dim >= 0 && dim <= 2);
+
+ append_code(tile.scalar(0, 0).str, " = get_global_id(", std::to_string(dim), ");\n");
+}
+
+void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileOperand> &operands)
+{
+ std::string format_code;
+ std::string args_code;
+
+ for(auto &op : operands)
+ {
+ const auto &tile = to_cl_tile(op);
+ const auto &info = tile.info();
+
+ const auto &name = tile.name();
+ const auto width = info.width();
+ const auto height = info.height();
+ const auto data_type = info.data_type();
+
+ // Construct the format specifier to print out one row of the tile.
+ std::string row_format("%");
+
+ if(width > 1)
+ {
+ row_format += "v" + std::to_string(width);
+ }
+
+ switch(data_type)
+ {
+ case DataType::Fp32:
+ row_format += "hlg";
+ break;
+ case DataType::Fp16:
+ row_format += "hg";
+ break;
+ case DataType::Int32:
+ case DataType::Bool:
+ row_format += (width > 1) ? "hli" : "i";
+ break;
+ case DataType::Int16:
+ row_format += "hi";
+ break;
+ case DataType::Int8:
+ row_format += "hhi";
+ break;
+ case DataType::Uint32:
+ row_format += (width > 1) ? "hlu" : "u";
+ break;
+ case DataType::Uint16:
+ row_format += "hu";
+ break;
+ case DataType::Uint8:
+ row_format += "hhu";
+ break;
+ default:
+ CKW_THROW_MSG("Unsupported data type!");
+ }
+
+ if(width > 1)
+ {
+ row_format = "[" + row_format + "]";
+ }
+
+ // Construct the format specifier for the printf statement.
+ format_code += name + " = ";
+
+ if(height == 1)
+ {
+ format_code += row_format;
+ }
+ else
+ {
+ format_code += "[" + row_format;
+ for(int32_t row = 1; row < height; ++row)
+ {
+ format_code += ", " + row_format;
+ }
+ format_code += "]";
+ }
+
+ format_code += "\\n";
+
+ // Construct the variable arguments for the printf statement.
+ for(int32_t row = 0; row < height; ++row)
+ {
+ args_code += ", " + tile.vector(row).str;
+ }
+ }
+
+ append_code("printf(\"", prefix, "\\n", format_code, "\"", args_code, ");\n");
+}
+
void CLKernelWriter::op_comment(const std::string &text)
{
#ifdef COMPUTE_KERNEL_WRITER_DEBUG_ENABLED
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h
index 9458ced916..c494847944 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.h
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.h
@@ -95,10 +95,14 @@ public:
// Misc
// =============================================================================================
+ void op_get_global_id(const TileOperand &dst, int32_t dim) override;
+
void op_comment(const std::string &text) override;
void op_write_raw_code(const std::string &raw_code) override;
+ void op_print(const std::string &prefix, const std::vector<TileOperand> &operands) override;
+
// =============================================================================================
// Code generation
// =============================================================================================