diff options
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.cpp')
-rw-r--r-- | compute_kernel_writer/src/cl/CLKernelWriter.cpp | 282 |
1 files changed, 170 insertions, 112 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp index 2db9c139b7..62e6853a7a 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp +++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp @@ -31,14 +31,15 @@ #include "ckw/types/DataType.h" #include "ckw/types/MemoryOperation.h" #include "ckw/types/TargetLanguage.h" -#include "src/ITensorComponent.h" -#include "src/TileView.h" + #include "src/cl/CLHelpers.h" #include "src/cl/CLTensorArgument.h" #include "src/cl/CLTile.h" #include "src/cl/helpers/CLMemoryOpBufferHelper.h" #include "src/cl/helpers/CLMemoryOpImage2dHelper.h" #include "src/cl/helpers/ICLMemoryOpHelper.h" +#include "src/ITensorComponent.h" +#include "src/TileView.h" #include "src/types/DataTypeHelpers.h" #include <algorithm> @@ -63,14 +64,14 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name) // Create the list of arguments. std::vector<KernelArgument> arguments; - for(const auto &tensor : _tensors) + for (const auto &tensor : _tensors) { const auto tensor_id = tensor->info().id(); const auto storages = tensor->storages(); const auto components = tensor->components(); - for(const auto &storage : storages) + for (const auto &storage : storages) { code += cl_get_variable_storagetype_as_string(storage.type); code += " "; @@ -80,7 +81,7 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name) arguments.emplace_back(tensor_id, storage.type); } - for(const auto &component : components) + for (const auto &component : components) { const auto &tile = component->tile(); const auto &tile_info = tile.info(); @@ -96,7 +97,7 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name) } } - if(code.size() >= 2 && code[code.size() - 2] == ',' && code[code.size() - 1] == '\n') + if (code.size() >= 2 && code[code.size() - 2] == ',' && code[code.size() - 1] == '\n') { // Remove the last comma in the argument list. code.pop_back(); @@ -127,11 +128,12 @@ void CLKernelWriter::op_assign(const TileOperand &dst, const TileOperand &src) const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : ""; CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match."); - CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, + "Tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension."); // Broadcasting on y dimension is automatic (see CLTile::vector). - for(int32_t y = 0; y < dst_h; ++y) + for (int32_t y = 0; y < dst_h; ++y) { append_code(dst_view.vector(y).str, " = ", src_prefix, src_view.vector(y).str, ";\n"); } @@ -158,13 +160,15 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con const std::string prefix = broadcast_x ? "(" + dst_type_str + ")" : ""; CKW_ASSERT_MSG(src_view.data_type() != dst_view.data_type(), "Source and destination type must be different."); - CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, + "Tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension."); // Broadcasting on y dimension is automatic (see CLTile::vector). - for(int32_t y = 0; y < dst_h; ++y) + for (int32_t y = 0; y < dst_h; ++y) { - append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", src_view.vector(y).str, ");\n"); + append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", + src_view.vector(y).str, ");\n"); } } @@ -189,11 +193,12 @@ void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOper const auto op_suffix = op_is_func ? ")" : ""; CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match."); - CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, + "Tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension."); // Broadcasting on y dimension is automatic (see CLTile::vector). - for(int32_t y = 0; y < dst_h; ++y) + for (int32_t y = 0; y < dst_h; ++y) { append_code(dst_view.vector(y).str, " = ", src_prefix, op_prefix, src_view.vector(y).str, op_suffix, ";\n"); } @@ -214,27 +219,28 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp CKW_ASSERT_MSG(lhs_view.data_type() == rhs_view.data_type(), "LHS and RHS type must match."); - CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1, "LHS tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1, "RHS tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1, + "LHS tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1, + "RHS tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, "LHS tile width must match destination or LHS is broadcasting in x dimension."); - CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, "RHS tile width must match destination or RHS is broadcasting in x dimension."); + CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, + "LHS tile width must match destination or LHS is broadcasting in x dimension."); + CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, + "RHS tile width must match destination or RHS is broadcasting in x dimension."); - if(op == BinaryOp::MatMul_Nt_T) + if (op == BinaryOp::MatMul_Nt_T) { CKW_ASSERT(is_data_type_float(data_type)); - for(int32_t y = 0; y < dst_h; ++y) + for (int32_t y = 0; y < dst_h; ++y) { - for(int32_t x = 0; x < dst_w; ++x) + for (int32_t x = 0; x < dst_w; ++x) { - for(int32_t k = 0; k < lhs_w; ++k) + for (int32_t k = 0; k < lhs_w; ++k) { - append_code( - dst_view.scalar(x, y).str, " = fma(", - lhs_view.scalar(k, y).str, ", ", - rhs_view.scalar(k, x).str, ", ", - dst_view.scalar(x, y).str, ");\n"); + append_code(dst_view.scalar(x, y).str, " = fma(", lhs_view.scalar(k, y).str, ", ", + rhs_view.scalar(k, x).str, ", ", dst_view.scalar(x, y).str, ");\n"); } } } @@ -258,14 +264,16 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp const std::string op_suffix = op_is_func ? ");\n" : ";\n"; // Broadcasting on y dimension is automatic (see CLTile::vector). - for(int32_t y = 0; y < dst_h; ++y) + for (int32_t y = 0; y < dst_h; ++y) { - append_code(dst_view.vector(y).str, op_prefix, lhs_prefix, lhs_view.vector(y).str, op_separator, rhs_prefix, rhs_view.vector(y).str, op_suffix); + append_code(dst_view.vector(y).str, op_prefix, lhs_prefix, lhs_view.vector(y).str, op_separator, rhs_prefix, + rhs_view.vector(y).str, op_suffix); } } } -void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) +void CLKernelWriter::op_ternary( + const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) { const auto dst_view = to_cl_tile_view(dst); const auto first_view = to_cl_tile_view(first); @@ -297,37 +305,42 @@ void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const Tile CKW_ASSERT_MSG(second_view.data_type() == dst_view.data_type(), "2nd source and destination type must match."); CKW_ASSERT_MSG(third_view.data_type() == dst_view.data_type(), "3rd source and destination type must match."); - CKW_ASSERT_MSG(first_view.height() == dst_h || first_view.height() == 1, "1st tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(second_view.height() == dst_h || second_view.height() == 1, "2nd tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(third_view.height() == dst_h || third_view.height() == 1, "3rd tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(first_view.height() == dst_h || first_view.height() == 1, + "1st tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(second_view.height() == dst_h || second_view.height() == 1, + "2nd tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(third_view.height() == dst_h || third_view.height() == 1, + "3rd tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, "1st tile width must match or source is broadcasting in x dimension."); - CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, "2nd tile width must match or source is broadcasting in x dimension."); - CKW_ASSERT_MSG(third_w == dst_w || third_w == 1, "3rd tile width must match or source is broadcasting in x dimension."); + CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, + "1st tile width must match or source is broadcasting in x dimension."); + CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, + "2nd tile width must match or source is broadcasting in x dimension."); + CKW_ASSERT_MSG(third_w == dst_w || third_w == 1, + "3rd tile width must match or source is broadcasting in x dimension."); // Broadcasting on y dimension is automatic (see CLTile::vector). - for(int32_t y = 0; y < dst_h; ++y) + for (int32_t y = 0; y < dst_h; ++y) { - append_code( - dst_view.vector(y).str, " = ", op_name, "(", - first_prefix, first_view.vector(y).str, ", ", - second_prefix, second_view.vector(y).str, ", ", - third_prefix, third_view.vector(y).str, ");\n"); + append_code(dst_view.vector(y).str, " = ", op_name, "(", first_prefix, first_view.vector(y).str, ", ", + second_prefix, second_view.vector(y).str, ", ", third_prefix, third_view.vector(y).str, ");\n"); } } -void CLKernelWriter::op_if_generic(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if) +void CLKernelWriter::op_if_generic( + const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if) { const auto lhs_view = to_cl_tile_view(lhs); const auto rhs_view = to_cl_tile_view(rhs); const auto op_name = std::get<1>(cl_get_binary_op(op, lhs_view.data_type())); - CKW_ASSERT(op == BinaryOp::Less || op == BinaryOp::LessEqual || op == BinaryOp::Equal || op == BinaryOp::GreaterEqual || op == BinaryOp::Greater); + CKW_ASSERT(op == BinaryOp::Less || op == BinaryOp::LessEqual || op == BinaryOp::Equal || + op == BinaryOp::GreaterEqual || op == BinaryOp::Greater); CKW_ASSERT(lhs_view.is_scalar()); CKW_ASSERT(rhs_view.is_scalar()); - if(is_else_if) + if (is_else_if) { append_code("else "); } @@ -337,12 +350,18 @@ void CLKernelWriter::op_if_generic(const TileOperand &lhs, BinaryOp op, const Ti append_code("}\n"); } -void CLKernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) +void CLKernelWriter::op_if(const TileOperand &lhs, + BinaryOp op, + const TileOperand &rhs, + const std::function<void()> &body) { op_if_generic(lhs, op, rhs, body, false /* is_else_if */); } -void CLKernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) +void CLKernelWriter::op_else_if(const TileOperand &lhs, + BinaryOp op, + const TileOperand &rhs, + const std::function<void()> &body) { op_if_generic(lhs, op, rhs, body, true /* is_else_if */); } @@ -354,10 +373,13 @@ void CLKernelWriter::op_else(const std::function<void()> &body) append_code("}\n"); } -void CLKernelWriter::op_for_loop( - const TileOperand &var, BinaryOp cond_op, const TileOperand &cond_value, - const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value, - const std::function<void()> &body) +void CLKernelWriter::op_for_loop(const TileOperand &var, + BinaryOp cond_op, + const TileOperand &cond_value, + const TileOperand &update_var, + AssignmentOp update_op, + const TileOperand &update_value, + const std::function<void()> &body) { const auto var_view = to_cl_tile_view(var); const auto cond_value_view = to_cl_tile_view(cond_value); @@ -373,11 +395,12 @@ void CLKernelWriter::op_for_loop( CKW_ASSERT(update_var_view.data_type() == update_value_view.data_type()); const auto cond_op_name = std::get<1>(cl_get_binary_op(cond_op, var_view.data_type())); - CKW_ASSERT(cond_op == BinaryOp::Less || cond_op == BinaryOp::LessEqual || cond_op == BinaryOp::Equal || cond_op == BinaryOp::GreaterEqual || cond_op == BinaryOp::Greater); + CKW_ASSERT(cond_op == BinaryOp::Less || cond_op == BinaryOp::LessEqual || cond_op == BinaryOp::Equal || + cond_op == BinaryOp::GreaterEqual || cond_op == BinaryOp::Greater); - append_code( - "for (; ", var_view.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_view.scalar(0, 0).str, "; ", - update_var_view.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", update_value_view.scalar(0, 0).str, ")\n{\n"); + append_code("for (; ", var_view.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_view.scalar(0, 0).str, "; ", + update_var_view.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", + update_value_view.scalar(0, 0).str, ")\n{\n"); write_body(body); append_code("}\n"); } @@ -404,7 +427,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO std::string format_code; std::string args_code; - for(auto &op : operands) + for (auto &op : operands) { const auto tile_view = to_cl_tile_view(op); @@ -416,12 +439,12 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO // Construct the format specifier to print out one row of the tile. std::string row_format("%"); - if(width > 1) + if (width > 1) { row_format += "v" + std::to_string(width); } - switch(data_type) + switch (data_type) { case DataType::Fp32: row_format += "hlg"; @@ -452,7 +475,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO CKW_THROW_MSG("Unsupported data type!"); } - if(width > 1) + if (width > 1) { row_format = "[" + row_format + "]"; } @@ -460,14 +483,14 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO // Construct the format specifier for the printf statement. format_code += name + " = "; - if(height == 1) + if (height == 1) { format_code += row_format; } else { format_code += "[" + row_format; - for(int32_t row = 1; row < height; ++row) + for (int32_t row = 1; row < height; ++row) { format_code += ", " + row_format; } @@ -477,7 +500,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO format_code += "\\n"; // Construct the variable arguments for the printf statement. - for(int32_t row = 0; row < height; ++row) + for (int32_t row = 0; row < height; ++row) { args_code += ", " + tile_view.vector(row).str; } @@ -527,19 +550,14 @@ TileOperand CLKernelWriter::declare_tile(const std::string &name, const TileInfo const int32_t width = tile_info.width(); const DataType data_type = tile_info.data_type(); - CKW_ASSERT_MSG( - std::find_if( - _tiles.begin(), _tiles.end(), - [=](const std::unique_ptr<CLTile> &e) - { - return e->name() == fullname; - }) - == _tiles.end(), - "There is already a tile with name: " + fullname); + CKW_ASSERT_MSG(std::find_if(_tiles.begin(), _tiles.end(), + [=](const std::unique_ptr<CLTile> &e) + { return e->name() == fullname; }) == _tiles.end(), + "There is already a tile with name: " + fullname); auto tile = std::make_unique<CLTile>(fullname, tile_info); - for(int32_t row = 0; row < height; ++row) + for (int32_t row = 0; row < height; ++row) { const std::string cl_type = cl_get_variable_datatype_as_string(data_type, width); append_code(cl_type, " ", tile->vector(row).str, ";\n"); @@ -578,40 +596,40 @@ TileView<CLTile> CLKernelWriter::to_cl_tile_view(const TileOperand &operand) con { bool found = false; - for(const auto &t : _tiles) + for (const auto &t : _tiles) { - if(&tile == t.get()) + if (&tile == t.get()) { found = true; break; } } - for(const auto &t : _constant_tiles) + for (const auto &t : _constant_tiles) { - if(&tile == t.get()) + if (&tile == t.get()) { found = true; break; } } - if(!found) + if (!found) { - for(const auto &t : _tensors) + for (const auto &t : _tensors) { const auto components = t->components(); - for(const auto component : components) + for (const auto component : components) { - if(&tile == &component->tile()) + if (&tile == &component->tile()) { found = true; break; } } - if(found) + if (found) { break; } @@ -622,66 +640,106 @@ TileView<CLTile> CLKernelWriter::to_cl_tile_view(const TileOperand &operand) con } #endif // COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED - return { static_cast<CLTile &>(tile), area }; + return {static_cast<CLTile &>(tile), area}; } -void CLKernelWriter::op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) +void CLKernelWriter::op_load(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch) { - const CLTile dilation_x({ { "1" } }, DataType::Int32); - const CLTile dilation_y({ { "1" } }, DataType::Int32); + const CLTile dilation_x({{"1"}}, DataType::Int32); + const CLTile dilation_y({{"1"}}, DataType::Int32); - op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, false /* indirect buffer */); + op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, + false /* indirect buffer */); } -void CLKernelWriter::op_load_dilated(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const TileOperand &dilation_x, const TileOperand &dilation_y) +void CLKernelWriter::op_load_dilated(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch, + const TileOperand &dilation_x, + const TileOperand &dilation_y) { const auto dil_x_view = to_cl_tile_view(dilation_x); const auto dil_y_view = to_cl_tile_view(dilation_y); - op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */); + op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, + false /* indirect buffer */); } -void CLKernelWriter::op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) +void CLKernelWriter::op_store(const TensorOperand &tensor_op, + const TileOperand &tile_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch) { - const CLTile dilation_x({ { "1" } }, DataType::Int32); - const CLTile dilation_y({ { "1" } }, DataType::Int32); + const CLTile dilation_x({{"1"}}, DataType::Int32); + const CLTile dilation_y({{"1"}}, DataType::Int32); - op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, false /* indirect buffer */); + op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, + false /* indirect buffer */); } -void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const TileOperand &dilation_x, const TileOperand &dilation_y) +void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op, + const TileOperand &tile_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch, + const TileOperand &dilation_x, + const TileOperand &dilation_y) { const auto dil_x_view = to_cl_tile_view(dilation_x); const auto dil_y_view = to_cl_tile_view(dilation_y); - op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */); + op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, + false /* indirect buffer */); } -void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) +void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch) { - const CLTile dilation_x({ { "1" } }, DataType::Int32); - const CLTile dilation_y({ { "1" } }, DataType::Int32); + const CLTile dilation_x({{"1"}}, DataType::Int32); + const CLTile dilation_y({{"1"}}, DataType::Int32); - op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, true /* indirect buffer */); + op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, + true /* indirect buffer */); } -void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer) +void CLKernelWriter::op_load_store(MemoryOperation op, + const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch, + const TileView<CLTile> &dilation_x, + const TileView<CLTile> &dilation_y, + bool indirect_buffer) { CKW_UNUSED(dilation_x); CKW_ASSERT(dilation_x.is_scalar()); CKW_ASSERT(dilation_y.is_scalar()); CKW_ASSERT(dilation_x.scalar(0, 0).str == "((int)(1))"); // Dilation in x dimension is not implemented yet - if(indirect_buffer) + if (indirect_buffer) { CKW_ASSERT(dilation_y.scalar(0, 0).str == "((int)(1))" && dilation_x.scalar(0, 0).str == "((int)(1))"); } @@ -689,7 +747,7 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o ITensor &tensor = get_tensor(tensor_op); std::unique_ptr<ICLMemoryOpHelper> helper; - switch(sampler.storage()) + switch (sampler.storage()) { case TensorStorageType::BufferUint8Ptr: helper = std::make_unique<CLMemoryOpBufferHelper>(this, &tensor, &sampler, op); @@ -717,13 +775,13 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o helper->initialize(&tile, &x_tile, &z_tile, &batch_tile); - for(int row = 0; row < tile.info().height(); ++row) + for (int row = 0; row < tile.info().height(); ++row) { - if(!indirect_buffer) + if (!indirect_buffer) { std::string coord_y = y_tile.scalar(0, 0).str + " + " + std::to_string(row); - if(dilation_y.scalar(0, 0).str != "((int)(1))") + if (dilation_y.scalar(0, 0).str != "((int)(1))") { coord_y += " * " + dilation_y.scalar(0, 0).str; } |