about | summary | refs | log | tree | commit | diff
path: root/compute_kernel_writer/src/cl/CLKernelWriter.cpp
diff options
context:
space:
mode:
author Viet-Hoa Do <viet-hoa.do@arm.com> 2023-09-19 16:41:34 +0100
committer Viet-Hoa Do <viet-hoa.do@arm.com> 2023-09-22 12:07:09 +0000
commit cd1f03e765ad0f3ca3b68b1a7c1d0a1539cab439 (patch)
tree 9cb78579e01e14501c316f5297c804ba13c8ad37 /compute_kernel_writer/src/cl/CLKernelWriter.cpp
parent 1f841a52f9a7f52948d676bc3807461bbed6f70a (diff)
download ComputeLibrary-cd1f03e765ad0f3ca3b68b1a7c1d0a1539cab439.tar.gz
Add row vector and scalar access support to tile operand
* Add the concept of tile view which refers to a specific rectangular area of the tile object.
  - The active area is added to TileOperand so that the user can access part of the tile.
  - Currently only row vector and scalar access are exposed to the user.
  - All writing operations except load/store op support sub-tile.
* Add tests for sub-tile access.

Resolves: COMPMID-6557
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: Ica3f9eaf17f06e080c495d36c572f623b62c2910
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10354
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.cpp')
-rw-r--r-- compute_kernel_writer/src/cl/CLKernelWriter.cpp 229
1 files changed, 116 insertions, 113 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index 4074da7912..2db9c139b7 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -32,6 +32,7 @@
#include "ckw/types/MemoryOperation.h"
#include "ckw/types/TargetLanguage.h"
#include "src/ITensorComponent.h"
+#include "src/TileView.h"
#include "src/cl/CLHelpers.h"
#include "src/cl/CLTensorArgument.h"
#include "src/cl/CLTile.h"
@@ -42,6 +43,7 @@
#include <algorithm>
#include <cstdint>
+#include <tuple>
#include <vector>
namespace ckw
@@ -112,39 +114,39 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name)
void CLKernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &src_tile = to_cl_tile(src);
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto src_view = to_cl_tile_view(src);
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto src_w = src_tile.info().width();
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto src_w = src_view.width();
- const auto data_type_str = cl_get_variable_datatype_as_string(dst_tile.info().data_type(), dst_w);
+ const auto data_type_str = cl_get_variable_datatype_as_string(dst_view.data_type(), dst_w);
const auto broadcast_src_x = dst_w != 1 && src_w == 1;
const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : "";
- CKW_ASSERT_MSG(src_tile.info().data_type() == dst_tile.info().data_type(), "Source and destination type must match.");
- CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
for(int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_tile.vector(y).str, " = ", src_prefix, src_tile.vector(y).str, ";\n");
+ append_code(dst_view.vector(y).str, " = ", src_prefix, src_view.vector(y).str, ";\n");
}
}
void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &src_tile = to_cl_tile(src);
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto src_view = to_cl_tile_view(src);
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto src_w = src_tile.info().width();
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto src_w = src_view.width();
- const auto dst_type = dst_tile.info().data_type();
+ const auto dst_type = dst_view.data_type();
const auto convert_type_str = cl_get_variable_datatype_as_string(dst_type, src_w);
const auto dst_type_str = cl_get_variable_datatype_as_string(dst_type, dst_w);
@@ -155,27 +157,27 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con
const auto broadcast_x = dst_w != 1 && src_w == 1;
const std::string prefix = broadcast_x ? "(" + dst_type_str + ")" : "";
- CKW_ASSERT_MSG(src_tile.info().data_type() != dst_tile.info().data_type(), "Source and destination type must be different.");
- CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.data_type() != dst_view.data_type(), "Source and destination type must be different.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
for(int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_tile.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", src_tile.vector(y).str, ");\n");
+ append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", src_view.vector(y).str, ");\n");
}
}
void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &src_tile = to_cl_tile(src);
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto src_view = to_cl_tile_view(src);
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto src_w = src_tile.info().width();
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto src_w = src_view.width();
- const auto data_type_str = cl_get_variable_datatype_as_string(dst_tile.info().data_type(), dst_w);
+ const auto data_type_str = cl_get_variable_datatype_as_string(dst_view.data_type(), dst_w);
const auto broadcast_src_x = dst_w != 1 && src_w == 1;
const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : "";
@@ -186,35 +188,34 @@ void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOper
const auto op_prefix = op_is_func ? op_name + "(" : op_name;
const auto op_suffix = op_is_func ? ")" : "";
- CKW_ASSERT_MSG(src_tile.info().data_type() == dst_tile.info().data_type(), "Source and destination type must match.");
- CKW_ASSERT_MSG(!is_data_type_float(src_tile.info().data_type()), "Logical and bitwise not only work with integer.");
- CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
for(int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_tile.vector(y).str, " = ", src_prefix, op_prefix, src_tile.vector(y).str, op_suffix, ";\n");
+ append_code(dst_view.vector(y).str, " = ", src_prefix, op_prefix, src_view.vector(y).str, op_suffix, ";\n");
}
}
void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &lhs_tile = to_cl_tile(first);
- const auto &rhs_tile = to_cl_tile(second);
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto lhs_view = to_cl_tile_view(first);
+ const auto rhs_view = to_cl_tile_view(second);
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto lhs_w = lhs_tile.info().width();
- const auto rhs_w = rhs_tile.info().width();
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto lhs_w = lhs_view.width();
+ const auto rhs_w = rhs_view.width();
- const auto data_type = lhs_tile.info().data_type();
+ const auto data_type = lhs_view.data_type();
- CKW_ASSERT_MSG(lhs_tile.info().data_type() == rhs_tile.info().data_type(), "LHS and RHS type must match.");
+ CKW_ASSERT_MSG(lhs_view.data_type() == rhs_view.data_type(), "LHS and RHS type must match.");
- CKW_ASSERT_MSG(lhs_tile.info().height() == dst_h || lhs_tile.info().height() == 1, "LHS tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(rhs_tile.info().height() == dst_h || rhs_tile.info().height() == 1, "RHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1, "LHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1, "RHS tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, "LHS tile width must match destination or LHS is broadcasting in x dimension.");
CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, "RHS tile width must match destination or RHS is broadcasting in x dimension.");
@@ -230,10 +231,10 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp
for(int32_t k = 0; k < lhs_w; ++k)
{
append_code(
- dst_tile.scalar(x, y).str, " = fma(",
- lhs_tile.scalar(k, y).str, ", ",
- rhs_tile.scalar(k, x).str, ", ",
- dst_tile.scalar(x, y).str, ");\n");
+ dst_view.scalar(x, y).str, " = fma(",
+ lhs_view.scalar(k, y).str, ", ",
+ rhs_view.scalar(k, x).str, ", ",
+ dst_view.scalar(x, y).str, ");\n");
}
}
}
@@ -259,25 +260,25 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp
// Broadcasting on y dimension is automatic (see CLTile::vector).
for(int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_tile.vector(y).str, op_prefix, lhs_prefix, lhs_tile.vector(y).str, op_separator, rhs_prefix, rhs_tile.vector(y).str, op_suffix);
+ append_code(dst_view.vector(y).str, op_prefix, lhs_prefix, lhs_view.vector(y).str, op_separator, rhs_prefix, rhs_view.vector(y).str, op_suffix);
}
}
}
void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &first_tile = to_cl_tile(first);
- const auto &second_tile = to_cl_tile(second);
- const auto &third_tile = to_cl_tile(third);
-
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto first_w = first_tile.info().width();
- const auto second_w = second_tile.info().width();
- const auto third_w = third_tile.info().width();
-
- const auto data_type = dst_tile.info().data_type();
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto first_view = to_cl_tile_view(first);
+ const auto second_view = to_cl_tile_view(second);
+ const auto third_view = to_cl_tile_view(third);
+
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto first_w = first_view.width();
+ const auto second_w = second_view.width();
+ const auto third_w = third_view.width();
+
+ const auto data_type = dst_view.data_type();
const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w);
const auto op_info = cl_get_ternary_op(op);
@@ -293,12 +294,12 @@ void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const Tile
const std::string third_prefix = broadcast_third_x ? "(" + data_type_str + ")" : "";
CKW_ASSERT_MSG(op_is_func, "The only supported ternary operator is function.");
- CKW_ASSERT_MSG(second_tile.info().data_type() == dst_tile.info().data_type(), "2nd source and destination type must match.");
- CKW_ASSERT_MSG(third_tile.info().data_type() == dst_tile.info().data_type(), "3rd source and destination type must match.");
+ CKW_ASSERT_MSG(second_view.data_type() == dst_view.data_type(), "2nd source and destination type must match.");
+ CKW_ASSERT_MSG(third_view.data_type() == dst_view.data_type(), "3rd source and destination type must match.");
- CKW_ASSERT_MSG(first_tile.info().height() == dst_h || first_tile.info().height() == 1, "1st tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(second_tile.info().height() == dst_h || second_tile.info().height() == 1, "2nd tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(third_tile.info().height() == dst_h || third_tile.info().height() == 1, "3rd tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(first_view.height() == dst_h || first_view.height() == 1, "1st tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(second_view.height() == dst_h || second_view.height() == 1, "2nd tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(third_view.height() == dst_h || third_view.height() == 1, "3rd tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, "1st tile width must match or source is broadcasting in x dimension.");
CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, "2nd tile width must match or source is broadcasting in x dimension.");
@@ -308,30 +309,30 @@ void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const Tile
for(int32_t y = 0; y < dst_h; ++y)
{
append_code(
- dst_tile.vector(y).str, " = ", op_name, "(",
- first_prefix, first_tile.vector(y).str, ", ",
- second_prefix, second_tile.vector(y).str, ", ",
- third_prefix, third_tile.vector(y).str, ");\n");
+ dst_view.vector(y).str, " = ", op_name, "(",
+ first_prefix, first_view.vector(y).str, ", ",
+ second_prefix, second_view.vector(y).str, ", ",
+ third_prefix, third_view.vector(y).str, ");\n");
}
}
void CLKernelWriter::op_if_generic(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if)
{
- const auto &lhs_tile = to_cl_tile(lhs);
- const auto &rhs_tile = to_cl_tile(rhs);
+ const auto lhs_view = to_cl_tile_view(lhs);
+ const auto rhs_view = to_cl_tile_view(rhs);
- const auto op_name = std::get<1>(cl_get_binary_op(op, lhs_tile.info().data_type()));
+ const auto op_name = std::get<1>(cl_get_binary_op(op, lhs_view.data_type()));
CKW_ASSERT(op == BinaryOp::Less || op == BinaryOp::LessEqual || op == BinaryOp::Equal || op == BinaryOp::GreaterEqual || op == BinaryOp::Greater);
- CKW_ASSERT(lhs_tile.is_scalar());
- CKW_ASSERT(rhs_tile.is_scalar());
+ CKW_ASSERT(lhs_view.is_scalar());
+ CKW_ASSERT(rhs_view.is_scalar());
if(is_else_if)
{
append_code("else ");
}
- append_code("if (", lhs_tile.scalar(0, 0).str, " ", op_name, " ", rhs_tile.scalar(0, 0).str, ")\n{\n");
+ append_code("if (", lhs_view.scalar(0, 0).str, " ", op_name, " ", rhs_view.scalar(0, 0).str, ")\n{\n");
write_body(body);
append_code("}\n");
}
@@ -358,25 +359,25 @@ void CLKernelWriter::op_for_loop(
const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value,
const std::function<void()> &body)
{
- const auto &var_tile = to_cl_tile(var);
- const auto &cond_value_tile = to_cl_tile(cond_value);
- const auto &update_var_tile = to_cl_tile(update_var);
- const auto &update_value_tile = to_cl_tile(update_value);
+ const auto var_view = to_cl_tile_view(var);
+ const auto cond_value_view = to_cl_tile_view(cond_value);
+ const auto update_var_view = to_cl_tile_view(update_var);
+ const auto update_value_view = to_cl_tile_view(update_value);
- CKW_ASSERT(var_tile.is_scalar());
- CKW_ASSERT(cond_value_tile.is_scalar());
- CKW_ASSERT(update_var_tile.is_scalar());
- CKW_ASSERT(update_value_tile.is_scalar());
+ CKW_ASSERT(var_view.is_scalar());
+ CKW_ASSERT(cond_value_view.is_scalar());
+ CKW_ASSERT(update_var_view.is_scalar());
+ CKW_ASSERT(update_value_view.is_scalar());
- CKW_ASSERT(var_tile.info().data_type() == cond_value_tile.info().data_type());
- CKW_ASSERT(update_var_tile.info().data_type() == update_value_tile.info().data_type());
+ CKW_ASSERT(var_view.data_type() == cond_value_view.data_type());
+ CKW_ASSERT(update_var_view.data_type() == update_value_view.data_type());
- const auto cond_op_name = std::get<1>(cl_get_binary_op(cond_op, var_tile.info().data_type()));
+ const auto cond_op_name = std::get<1>(cl_get_binary_op(cond_op, var_view.data_type()));
CKW_ASSERT(cond_op == BinaryOp::Less || cond_op == BinaryOp::LessEqual || cond_op == BinaryOp::Equal || cond_op == BinaryOp::GreaterEqual || cond_op == BinaryOp::Greater);
append_code(
- "for (; ", var_tile.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_tile.scalar(0, 0).str, "; ",
- update_var_tile.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", update_value_tile.scalar(0, 0).str, ")\n{\n");
+ "for (; ", var_view.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_view.scalar(0, 0).str, "; ",
+ update_var_view.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", update_value_view.scalar(0, 0).str, ")\n{\n");
write_body(body);
append_code("}\n");
}
@@ -388,14 +389,14 @@ void CLKernelWriter::op_return()
void CLKernelWriter::op_get_global_id(const TileOperand &dst, int32_t dim)
{
- const auto &tile = to_cl_tile(dst);
+ const auto tile_view = to_cl_tile_view(dst);
- CKW_ASSERT(tile.is_scalar());
- CKW_ASSERT(tile.info().data_type() == DataType::Int32 || tile.info().data_type() == DataType::Uint32);
+ CKW_ASSERT(tile_view.is_scalar());
+ CKW_ASSERT(tile_view.data_type() == DataType::Int32 || tile_view.data_type() == DataType::Uint32);
CKW_ASSERT(dim >= 0 && dim <= 2);
- append_code(tile.scalar(0, 0).str, " = get_global_id(", std::to_string(dim), ");\n");
+ append_code(tile_view.scalar(0, 0).str, " = get_global_id(", std::to_string(dim), ");\n");
}
void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileOperand> &operands)
@@ -405,13 +406,12 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
for(auto &op : operands)
{
- const auto &tile = to_cl_tile(op);
- const auto &info = tile.info();
+ const auto tile_view = to_cl_tile_view(op);
- const auto &name = tile.name();
- const auto width = info.width();
- const auto height = info.height();
- const auto data_type = info.data_type();
+ const auto name = tile_view.name();
+ const auto width = tile_view.width();
+ const auto height = tile_view.height();
+ const auto data_type = tile_view.data_type();
// Construct the format specifier to print out one row of the tile.
std::string row_format("%");
@@ -479,7 +479,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
// Construct the variable arguments for the printf statement.
for(int32_t row = 0; row < height; ++row)
{
- args_code += ", " + tile.vector(row).str;
+ args_code += ", " + tile_view.vector(row).str;
}
}
@@ -566,9 +566,11 @@ void CLKernelWriter::op_write_raw_code(const std::string &raw_code)
append_code(raw_code);
}
-const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand) const
+TileView<CLTile> CLKernelWriter::to_cl_tile_view(const TileOperand &operand) const
{
- const auto &tile = get_tile(operand);
+ const auto tile_and_area = get_tile(operand);
+ ITile &tile = std::get<0>(tile_and_area);
+ const TileArea area = std::get<1>(tile_and_area);
#ifdef COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
// Check if the tile is a CLTile created by this kernel writer.
@@ -620,7 +622,7 @@ const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand) const
}
#endif // COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
- return static_cast<const CLTile &>(tile);
+ return { static_cast<CLTile &>(tile), area };
}
void CLKernelWriter::op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
@@ -636,10 +638,10 @@ void CLKernelWriter::op_load_dilated(const TileOperand &tile_op, const TensorOpe
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const TileOperand &dilation_x, const TileOperand &dilation_y)
{
- const auto &dil_x_tile = to_cl_tile(dilation_x);
- const auto &dil_y_tile = to_cl_tile(dilation_y);
+ const auto dil_x_view = to_cl_tile_view(dilation_x);
+ const auto dil_y_view = to_cl_tile_view(dilation_y);
- op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_tile, dil_y_tile, false /* indirect buffer */);
+ op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */);
}
void CLKernelWriter::op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
@@ -655,14 +657,14 @@ void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op, const Tile
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const TileOperand &dilation_x, const TileOperand &dilation_y)
{
- const auto &dil_x_tile = to_cl_tile(dilation_x);
- const auto &dil_y_tile = to_cl_tile(dilation_y);
+ const auto dil_x_view = to_cl_tile_view(dilation_x);
+ const auto dil_y_view = to_cl_tile_view(dilation_y);
- op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_tile, dil_y_tile, false /* indirect buffer */);
+ op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */);
}
void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
{
const CLTile dilation_x({ { "1" } }, DataType::Int32);
const CLTile dilation_y({ { "1" } }, DataType::Int32);
@@ -671,8 +673,8 @@ void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, const TensorOp
}
void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const CLTile &dilation_x, const CLTile &dilation_y, bool indirect_buffer)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
+ const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer)
{
CKW_UNUSED(dilation_x);
CKW_ASSERT(dilation_x.is_scalar());
@@ -681,7 +683,7 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o
if(indirect_buffer)
{
- CKW_ASSERT(dilation_y.scalar(0,0).str == "((int)(1))" && dilation_x.scalar(0,0).str == "((int)(1))");
+ CKW_ASSERT(dilation_y.scalar(0, 0).str == "((int)(1))" && dilation_x.scalar(0, 0).str == "((int)(1))");
}
ITensor &tensor = get_tensor(tensor_op);
@@ -700,11 +702,12 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o
CKW_THROW_MSG("Unsupported tensor storage");
}
- const auto &tile = to_cl_tile(tile_op);
- const auto &x_tile = to_cl_tile(x);
- const auto &y_tile = to_cl_tile(y);
- const auto &z_tile = to_cl_tile(z);
- const auto &batch_tile = to_cl_tile(batch);
+ // Load/store op doesn't support sub-tile access.
+ const auto tile = to_cl_tile_view(tile_op).full_tile();
+ const auto x_tile = to_cl_tile_view(x).full_tile();
+ const auto y_tile = to_cl_tile_view(y).full_tile();
+ const auto z_tile = to_cl_tile_view(z).full_tile();
+ const auto batch_tile = to_cl_tile_view(batch).full_tile();
CKW_ASSERT(x_tile.is_scalar());
CKW_ASSERT(z_tile.is_scalar());