aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/cl/CLKernelWriter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.cpp')
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.cpp282
1 files changed, 170 insertions, 112 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index 2db9c139b7..62e6853a7a 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -31,14 +31,15 @@
#include "ckw/types/DataType.h"
#include "ckw/types/MemoryOperation.h"
#include "ckw/types/TargetLanguage.h"
-#include "src/ITensorComponent.h"
-#include "src/TileView.h"
+
#include "src/cl/CLHelpers.h"
#include "src/cl/CLTensorArgument.h"
#include "src/cl/CLTile.h"
#include "src/cl/helpers/CLMemoryOpBufferHelper.h"
#include "src/cl/helpers/CLMemoryOpImage2dHelper.h"
#include "src/cl/helpers/ICLMemoryOpHelper.h"
+#include "src/ITensorComponent.h"
+#include "src/TileView.h"
#include "src/types/DataTypeHelpers.h"
#include <algorithm>
@@ -63,14 +64,14 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name)
// Create the list of arguments.
std::vector<KernelArgument> arguments;
- for(const auto &tensor : _tensors)
+ for (const auto &tensor : _tensors)
{
const auto tensor_id = tensor->info().id();
const auto storages = tensor->storages();
const auto components = tensor->components();
- for(const auto &storage : storages)
+ for (const auto &storage : storages)
{
code += cl_get_variable_storagetype_as_string(storage.type);
code += " ";
@@ -80,7 +81,7 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name)
arguments.emplace_back(tensor_id, storage.type);
}
- for(const auto &component : components)
+ for (const auto &component : components)
{
const auto &tile = component->tile();
const auto &tile_info = tile.info();
@@ -96,7 +97,7 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name)
}
}
- if(code.size() >= 2 && code[code.size() - 2] == ',' && code[code.size() - 1] == '\n')
+ if (code.size() >= 2 && code[code.size() - 2] == ',' && code[code.size() - 1] == '\n')
{
// Remove the last comma in the argument list.
code.pop_back();
@@ -127,11 +128,12 @@ void CLKernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : "";
CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match.");
- CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1,
+ "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
- for(int32_t y = 0; y < dst_h; ++y)
+ for (int32_t y = 0; y < dst_h; ++y)
{
append_code(dst_view.vector(y).str, " = ", src_prefix, src_view.vector(y).str, ";\n");
}
@@ -158,13 +160,15 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con
const std::string prefix = broadcast_x ? "(" + dst_type_str + ")" : "";
CKW_ASSERT_MSG(src_view.data_type() != dst_view.data_type(), "Source and destination type must be different.");
- CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1,
+ "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
- for(int32_t y = 0; y < dst_h; ++y)
+ for (int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", src_view.vector(y).str, ");\n");
+ append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(",
+ src_view.vector(y).str, ");\n");
}
}
@@ -189,11 +193,12 @@ void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOper
const auto op_suffix = op_is_func ? ")" : "";
CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match.");
- CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1,
+ "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
- for(int32_t y = 0; y < dst_h; ++y)
+ for (int32_t y = 0; y < dst_h; ++y)
{
append_code(dst_view.vector(y).str, " = ", src_prefix, op_prefix, src_view.vector(y).str, op_suffix, ";\n");
}
@@ -214,27 +219,28 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp
CKW_ASSERT_MSG(lhs_view.data_type() == rhs_view.data_type(), "LHS and RHS type must match.");
- CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1, "LHS tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1, "RHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1,
+ "LHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1,
+ "RHS tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, "LHS tile width must match destination or LHS is broadcasting in x dimension.");
- CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, "RHS tile width must match destination or RHS is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1,
+ "LHS tile width must match destination or LHS is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1,
+ "RHS tile width must match destination or RHS is broadcasting in x dimension.");
- if(op == BinaryOp::MatMul_Nt_T)
+ if (op == BinaryOp::MatMul_Nt_T)
{
CKW_ASSERT(is_data_type_float(data_type));
- for(int32_t y = 0; y < dst_h; ++y)
+ for (int32_t y = 0; y < dst_h; ++y)
{
- for(int32_t x = 0; x < dst_w; ++x)
+ for (int32_t x = 0; x < dst_w; ++x)
{
- for(int32_t k = 0; k < lhs_w; ++k)
+ for (int32_t k = 0; k < lhs_w; ++k)
{
- append_code(
- dst_view.scalar(x, y).str, " = fma(",
- lhs_view.scalar(k, y).str, ", ",
- rhs_view.scalar(k, x).str, ", ",
- dst_view.scalar(x, y).str, ");\n");
+ append_code(dst_view.scalar(x, y).str, " = fma(", lhs_view.scalar(k, y).str, ", ",
+ rhs_view.scalar(k, x).str, ", ", dst_view.scalar(x, y).str, ");\n");
}
}
}
@@ -258,14 +264,16 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp
const std::string op_suffix = op_is_func ? ");\n" : ";\n";
// Broadcasting on y dimension is automatic (see CLTile::vector).
- for(int32_t y = 0; y < dst_h; ++y)
+ for (int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_view.vector(y).str, op_prefix, lhs_prefix, lhs_view.vector(y).str, op_separator, rhs_prefix, rhs_view.vector(y).str, op_suffix);
+ append_code(dst_view.vector(y).str, op_prefix, lhs_prefix, lhs_view.vector(y).str, op_separator, rhs_prefix,
+ rhs_view.vector(y).str, op_suffix);
}
}
}
-void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third)
+void CLKernelWriter::op_ternary(
+ const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third)
{
const auto dst_view = to_cl_tile_view(dst);
const auto first_view = to_cl_tile_view(first);
@@ -297,37 +305,42 @@ void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const Tile
CKW_ASSERT_MSG(second_view.data_type() == dst_view.data_type(), "2nd source and destination type must match.");
CKW_ASSERT_MSG(third_view.data_type() == dst_view.data_type(), "3rd source and destination type must match.");
- CKW_ASSERT_MSG(first_view.height() == dst_h || first_view.height() == 1, "1st tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(second_view.height() == dst_h || second_view.height() == 1, "2nd tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(third_view.height() == dst_h || third_view.height() == 1, "3rd tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(first_view.height() == dst_h || first_view.height() == 1,
+ "1st tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(second_view.height() == dst_h || second_view.height() == 1,
+ "2nd tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(third_view.height() == dst_h || third_view.height() == 1,
+ "3rd tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, "1st tile width must match or source is broadcasting in x dimension.");
- CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, "2nd tile width must match or source is broadcasting in x dimension.");
- CKW_ASSERT_MSG(third_w == dst_w || third_w == 1, "3rd tile width must match or source is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(first_w == dst_w || first_w == 1,
+ "1st tile width must match or source is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(second_w == dst_w || second_w == 1,
+ "2nd tile width must match or source is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(third_w == dst_w || third_w == 1,
+ "3rd tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
- for(int32_t y = 0; y < dst_h; ++y)
+ for (int32_t y = 0; y < dst_h; ++y)
{
- append_code(
- dst_view.vector(y).str, " = ", op_name, "(",
- first_prefix, first_view.vector(y).str, ", ",
- second_prefix, second_view.vector(y).str, ", ",
- third_prefix, third_view.vector(y).str, ");\n");
+ append_code(dst_view.vector(y).str, " = ", op_name, "(", first_prefix, first_view.vector(y).str, ", ",
+ second_prefix, second_view.vector(y).str, ", ", third_prefix, third_view.vector(y).str, ");\n");
}
}
-void CLKernelWriter::op_if_generic(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if)
+void CLKernelWriter::op_if_generic(
+ const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if)
{
const auto lhs_view = to_cl_tile_view(lhs);
const auto rhs_view = to_cl_tile_view(rhs);
const auto op_name = std::get<1>(cl_get_binary_op(op, lhs_view.data_type()));
- CKW_ASSERT(op == BinaryOp::Less || op == BinaryOp::LessEqual || op == BinaryOp::Equal || op == BinaryOp::GreaterEqual || op == BinaryOp::Greater);
+ CKW_ASSERT(op == BinaryOp::Less || op == BinaryOp::LessEqual || op == BinaryOp::Equal ||
+ op == BinaryOp::GreaterEqual || op == BinaryOp::Greater);
CKW_ASSERT(lhs_view.is_scalar());
CKW_ASSERT(rhs_view.is_scalar());
- if(is_else_if)
+ if (is_else_if)
{
append_code("else ");
}
@@ -337,12 +350,18 @@ void CLKernelWriter::op_if_generic(const TileOperand &lhs, BinaryOp op, const Ti
append_code("}\n");
}
-void CLKernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+void CLKernelWriter::op_if(const TileOperand &lhs,
+ BinaryOp op,
+ const TileOperand &rhs,
+ const std::function<void()> &body)
{
op_if_generic(lhs, op, rhs, body, false /* is_else_if */);
}
-void CLKernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+void CLKernelWriter::op_else_if(const TileOperand &lhs,
+ BinaryOp op,
+ const TileOperand &rhs,
+ const std::function<void()> &body)
{
op_if_generic(lhs, op, rhs, body, true /* is_else_if */);
}
@@ -354,10 +373,13 @@ void CLKernelWriter::op_else(const std::function<void()> &body)
append_code("}\n");
}
-void CLKernelWriter::op_for_loop(
- const TileOperand &var, BinaryOp cond_op, const TileOperand &cond_value,
- const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value,
- const std::function<void()> &body)
+void CLKernelWriter::op_for_loop(const TileOperand &var,
+ BinaryOp cond_op,
+ const TileOperand &cond_value,
+ const TileOperand &update_var,
+ AssignmentOp update_op,
+ const TileOperand &update_value,
+ const std::function<void()> &body)
{
const auto var_view = to_cl_tile_view(var);
const auto cond_value_view = to_cl_tile_view(cond_value);
@@ -373,11 +395,12 @@ void CLKernelWriter::op_for_loop(
CKW_ASSERT(update_var_view.data_type() == update_value_view.data_type());
const auto cond_op_name = std::get<1>(cl_get_binary_op(cond_op, var_view.data_type()));
- CKW_ASSERT(cond_op == BinaryOp::Less || cond_op == BinaryOp::LessEqual || cond_op == BinaryOp::Equal || cond_op == BinaryOp::GreaterEqual || cond_op == BinaryOp::Greater);
+ CKW_ASSERT(cond_op == BinaryOp::Less || cond_op == BinaryOp::LessEqual || cond_op == BinaryOp::Equal ||
+ cond_op == BinaryOp::GreaterEqual || cond_op == BinaryOp::Greater);
- append_code(
- "for (; ", var_view.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_view.scalar(0, 0).str, "; ",
- update_var_view.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", update_value_view.scalar(0, 0).str, ")\n{\n");
+ append_code("for (; ", var_view.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_view.scalar(0, 0).str, "; ",
+ update_var_view.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ",
+ update_value_view.scalar(0, 0).str, ")\n{\n");
write_body(body);
append_code("}\n");
}
@@ -404,7 +427,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
std::string format_code;
std::string args_code;
- for(auto &op : operands)
+ for (auto &op : operands)
{
const auto tile_view = to_cl_tile_view(op);
@@ -416,12 +439,12 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
// Construct the format specifier to print out one row of the tile.
std::string row_format("%");
- if(width > 1)
+ if (width > 1)
{
row_format += "v" + std::to_string(width);
}
- switch(data_type)
+ switch (data_type)
{
case DataType::Fp32:
row_format += "hlg";
@@ -452,7 +475,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
CKW_THROW_MSG("Unsupported data type!");
}
- if(width > 1)
+ if (width > 1)
{
row_format = "[" + row_format + "]";
}
@@ -460,14 +483,14 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
// Construct the format specifier for the printf statement.
format_code += name + " = ";
- if(height == 1)
+ if (height == 1)
{
format_code += row_format;
}
else
{
format_code += "[" + row_format;
- for(int32_t row = 1; row < height; ++row)
+ for (int32_t row = 1; row < height; ++row)
{
format_code += ", " + row_format;
}
@@ -477,7 +500,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
format_code += "\\n";
// Construct the variable arguments for the printf statement.
- for(int32_t row = 0; row < height; ++row)
+ for (int32_t row = 0; row < height; ++row)
{
args_code += ", " + tile_view.vector(row).str;
}
@@ -527,19 +550,14 @@ TileOperand CLKernelWriter::declare_tile(const std::string &name, const TileInfo
const int32_t width = tile_info.width();
const DataType data_type = tile_info.data_type();
- CKW_ASSERT_MSG(
- std::find_if(
- _tiles.begin(), _tiles.end(),
- [=](const std::unique_ptr<CLTile> &e)
- {
- return e->name() == fullname;
- })
- == _tiles.end(),
- "There is already a tile with name: " + fullname);
+ CKW_ASSERT_MSG(std::find_if(_tiles.begin(), _tiles.end(),
+ [=](const std::unique_ptr<CLTile> &e)
+ { return e->name() == fullname; }) == _tiles.end(),
+ "There is already a tile with name: " + fullname);
auto tile = std::make_unique<CLTile>(fullname, tile_info);
- for(int32_t row = 0; row < height; ++row)
+ for (int32_t row = 0; row < height; ++row)
{
const std::string cl_type = cl_get_variable_datatype_as_string(data_type, width);
append_code(cl_type, " ", tile->vector(row).str, ";\n");
@@ -578,40 +596,40 @@ TileView<CLTile> CLKernelWriter::to_cl_tile_view(const TileOperand &operand) con
{
bool found = false;
- for(const auto &t : _tiles)
+ for (const auto &t : _tiles)
{
- if(&tile == t.get())
+ if (&tile == t.get())
{
found = true;
break;
}
}
- for(const auto &t : _constant_tiles)
+ for (const auto &t : _constant_tiles)
{
- if(&tile == t.get())
+ if (&tile == t.get())
{
found = true;
break;
}
}
- if(!found)
+ if (!found)
{
- for(const auto &t : _tensors)
+ for (const auto &t : _tensors)
{
const auto components = t->components();
- for(const auto component : components)
+ for (const auto component : components)
{
- if(&tile == &component->tile())
+ if (&tile == &component->tile())
{
found = true;
break;
}
}
- if(found)
+ if (found)
{
break;
}
@@ -622,66 +640,106 @@ TileView<CLTile> CLKernelWriter::to_cl_tile_view(const TileOperand &operand) con
}
#endif // COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
- return { static_cast<CLTile &>(tile), area };
+ return {static_cast<CLTile &>(tile), area};
}
-void CLKernelWriter::op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
+void CLKernelWriter::op_load(const TileOperand &tile_op,
+ const TensorOperand &tensor_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch)
{
- const CLTile dilation_x({ { "1" } }, DataType::Int32);
- const CLTile dilation_y({ { "1" } }, DataType::Int32);
+ const CLTile dilation_x({{"1"}}, DataType::Int32);
+ const CLTile dilation_y({{"1"}}, DataType::Int32);
- op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, false /* indirect buffer */);
+ op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y,
+ false /* indirect buffer */);
}
-void CLKernelWriter::op_load_dilated(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const TileOperand &dilation_x, const TileOperand &dilation_y)
+void CLKernelWriter::op_load_dilated(const TileOperand &tile_op,
+ const TensorOperand &tensor_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch,
+ const TileOperand &dilation_x,
+ const TileOperand &dilation_y)
{
const auto dil_x_view = to_cl_tile_view(dilation_x);
const auto dil_y_view = to_cl_tile_view(dilation_y);
- op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */);
+ op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view,
+ false /* indirect buffer */);
}
-void CLKernelWriter::op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
+void CLKernelWriter::op_store(const TensorOperand &tensor_op,
+ const TileOperand &tile_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch)
{
- const CLTile dilation_x({ { "1" } }, DataType::Int32);
- const CLTile dilation_y({ { "1" } }, DataType::Int32);
+ const CLTile dilation_x({{"1"}}, DataType::Int32);
+ const CLTile dilation_y({{"1"}}, DataType::Int32);
- op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, false /* indirect buffer */);
+ op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y,
+ false /* indirect buffer */);
}
-void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const TileOperand &dilation_x, const TileOperand &dilation_y)
+void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op,
+ const TileOperand &tile_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch,
+ const TileOperand &dilation_x,
+ const TileOperand &dilation_y)
{
const auto dil_x_view = to_cl_tile_view(dilation_x);
const auto dil_y_view = to_cl_tile_view(dilation_y);
- op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */);
+ op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view,
+ false /* indirect buffer */);
}
-void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
+void CLKernelWriter::op_load_indirect(const TileOperand &tile_op,
+ const TensorOperand &tensor_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch)
{
- const CLTile dilation_x({ { "1" } }, DataType::Int32);
- const CLTile dilation_y({ { "1" } }, DataType::Int32);
+ const CLTile dilation_x({{"1"}}, DataType::Int32);
+ const CLTile dilation_y({{"1"}}, DataType::Int32);
- op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y, true /* indirect buffer */);
+ op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dilation_x, dilation_y,
+ true /* indirect buffer */);
}
-void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer)
+void CLKernelWriter::op_load_store(MemoryOperation op,
+ const TileOperand &tile_op,
+ const TensorOperand &tensor_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch,
+ const TileView<CLTile> &dilation_x,
+ const TileView<CLTile> &dilation_y,
+ bool indirect_buffer)
{
CKW_UNUSED(dilation_x);
CKW_ASSERT(dilation_x.is_scalar());
CKW_ASSERT(dilation_y.is_scalar());
CKW_ASSERT(dilation_x.scalar(0, 0).str == "((int)(1))"); // Dilation in x dimension is not implemented yet
- if(indirect_buffer)
+ if (indirect_buffer)
{
CKW_ASSERT(dilation_y.scalar(0, 0).str == "((int)(1))" && dilation_x.scalar(0, 0).str == "((int)(1))");
}
@@ -689,7 +747,7 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o
ITensor &tensor = get_tensor(tensor_op);
std::unique_ptr<ICLMemoryOpHelper> helper;
- switch(sampler.storage())
+ switch (sampler.storage())
{
case TensorStorageType::BufferUint8Ptr:
helper = std::make_unique<CLMemoryOpBufferHelper>(this, &tensor, &sampler, op);
@@ -717,13 +775,13 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o
helper->initialize(&tile, &x_tile, &z_tile, &batch_tile);
- for(int row = 0; row < tile.info().height(); ++row)
+ for (int row = 0; row < tile.info().height(); ++row)
{
- if(!indirect_buffer)
+ if (!indirect_buffer)
{
std::string coord_y = y_tile.scalar(0, 0).str + " + " + std::to_string(row);
- if(dilation_y.scalar(0, 0).str != "((int)(1))")
+ if (dilation_y.scalar(0, 0).str != "((int)(1))")
{
coord_y += " * " + dilation_y.scalar(0, 0).str;
}