diff options
author | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-08-29 16:01:13 +0100 |
---|---|---|
committer | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-08-30 12:32:10 +0000 |
commit | d0d8f2e61039826685aa076347eacce526e8c74b (patch) | |
tree | b655ee0c07c8c87ce387c370d20494169ae86d5c /compute_kernel_writer/src | |
parent | 87706692252b0746e882a9dd34ae64dc60acd767 (diff) | |
download | ComputeLibrary-d0d8f2e61039826685aa076347eacce526e8c74b.tar.gz |
Add get_global_id and printf for CKW
Resolves: COMPMID-6387
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I5bedb2fdb658a6eb5f1d5053b3840ca81cf75d03
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10214
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer/src')
-rw-r--r-- | compute_kernel_writer/src/cl/CLKernelWriter.cpp | 100 | ||||
-rw-r--r-- | compute_kernel_writer/src/cl/CLKernelWriter.h | 4 |
2 files changed, 104 insertions, 0 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp index 4284388c0b..a946b989d7 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp +++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp @@ -385,6 +385,106 @@ void CLKernelWriter::op_return() append_code("return;\n"); } +void CLKernelWriter::op_get_global_id(const TileOperand &dst, int32_t dim) +{ + const auto &tile = to_cl_tile(dst); + + CKW_ASSERT(tile.is_scalar()); + CKW_ASSERT(tile.info().data_type() == DataType::Int32 || tile.info().data_type() == DataType::Uint32); + + CKW_ASSERT(dim >= 0 && dim <= 2); + + append_code(tile.scalar(0, 0).str, " = get_global_id(", std::to_string(dim), ");\n"); +} + +void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileOperand> &operands) +{ + std::string format_code; + std::string args_code; + + for(auto &op : operands) + { + const auto &tile = to_cl_tile(op); + const auto &info = tile.info(); + + const auto &name = tile.name(); + const auto width = info.width(); + const auto height = info.height(); + const auto data_type = info.data_type(); + + // Construct the format specifier to print out one row of the tile. + std::string row_format("%"); + + if(width > 1) + { + row_format += "v" + std::to_string(width); + } + + switch(data_type) + { + case DataType::Fp32: + row_format += "hlg"; + break; + case DataType::Fp16: + row_format += "hg"; + break; + case DataType::Int32: + case DataType::Bool: + row_format += (width > 1) ? "hli" : "i"; + break; + case DataType::Int16: + row_format += "hi"; + break; + case DataType::Int8: + row_format += "hhi"; + break; + case DataType::Uint32: + row_format += (width > 1) ? "hlu" : "u"; + break; + case DataType::Uint16: + row_format += "hu"; + break; + case DataType::Uint8: + row_format += "hhu"; + break; + default: + CKW_THROW_MSG("Unsupported data type!"); + } + + if(width > 1) + { + row_format = "[" + row_format + "]"; + } + + // Construct the format specifier for the printf statement. + format_code += name + " = "; + + if(height == 1) + { + format_code += row_format; + } + else + { + format_code += "[" + row_format; + for(int32_t row = 1; row < height; ++row) + { + format_code += ", " + row_format; + } + format_code += "]"; + } + + format_code += "\\n"; + + // Construct the variable arguments for the printf statement. + for(int32_t row = 0; row < height; ++row) + { + args_code += ", " + tile.vector(row).str; + } + } + + append_code("printf(\"", prefix, "\\n", format_code, "\"", args_code, ");\n"); +} + void CLKernelWriter::op_comment(const std::string &text) { #ifdef COMPUTE_KERNEL_WRITER_DEBUG_ENABLED diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h index 9458ced916..c494847944 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.h +++ b/compute_kernel_writer/src/cl/CLKernelWriter.h @@ -95,10 +95,14 @@ public: // Misc // ============================================================================================= + void op_get_global_id(const TileOperand &dst, int32_t dim) override; + void op_comment(const std::string &text) override; void op_write_raw_code(const std::string &raw_code) override; + void op_print(const std::string &prefix, const std::vector<TileOperand> &operands) override; + // ============================================================================================= // Code generation // ============================================================================================= |