aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-09-19 16:41:34 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-09-22 12:07:09 +0000
commitcd1f03e765ad0f3ca3b68b1a7c1d0a1539cab439 (patch)
tree9cb78579e01e14501c316f5297c804ba13c8ad37
parent1f841a52f9a7f52948d676bc3807461bbed6f70a (diff)
downloadComputeLibrary-cd1f03e765ad0f3ca3b68b1a7c1d0a1539cab439.tar.gz
Add row vector and scalar access support to tile operand
* Add the concept of tile view which refers to a specific rectangular area of the tile object. - The active area is added to TileOperand so that the user can access part of the tile. - Currently only row vector and scalar access are exposed to the user. - All writing operations except load/store op support sub-tile. * Add tests for sub-tile access. Resolves: COMPMID-6557 Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: Ica3f9eaf17f06e080c495d36c572f623b62c2910 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10354 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--compute_kernel_writer/CMakeLists.txt1
-rw-r--r--compute_kernel_writer/include/ckw/KernelWriter.h8
-rw-r--r--compute_kernel_writer/include/ckw/TileOperand.h46
-rw-r--r--compute_kernel_writer/src/KernelWriter.cpp7
-rw-r--r--compute_kernel_writer/src/TileOperand.cpp38
-rw-r--r--compute_kernel_writer/src/TileView.cpp57
-rw-r--r--compute_kernel_writer/src/TileView.h189
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.cpp229
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.h11
-rw-r--r--compute_kernel_writer/src/cl/CLTile.cpp5
-rw-r--r--compute_kernel_writer/validation/Validation.cpp3
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterSubTileTest.h264
12 files changed, 732 insertions, 126 deletions
diff --git a/compute_kernel_writer/CMakeLists.txt b/compute_kernel_writer/CMakeLists.txt
index 170347c51c..69a4cdd51f 100644
--- a/compute_kernel_writer/CMakeLists.txt
+++ b/compute_kernel_writer/CMakeLists.txt
@@ -133,6 +133,7 @@ target_sources(ckw PRIVATE
src/TensorUtils.cpp
src/TileInfo.cpp
src/TileOperand.cpp
+ src/TileView.cpp
)
if(CKW_ENABLE_OPENCL)
diff --git a/compute_kernel_writer/include/ckw/KernelWriter.h b/compute_kernel_writer/include/ckw/KernelWriter.h
index 93ae8aecd6..15c99fe652 100644
--- a/compute_kernel_writer/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/include/ckw/KernelWriter.h
@@ -34,6 +34,7 @@
#include <functional>
#include <memory>
#include <string>
+#include <tuple>
namespace ckw
{
@@ -42,6 +43,7 @@ namespace ckw
class Kernel;
class TensorInfo;
class TensorSampler;
+class TileArea;
class TileInfo;
enum class DataType;
@@ -313,7 +315,7 @@ public:
* @param[in] batch batch
*/
virtual void op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch_op) = 0;
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch_op) = 0;
protected:
// =============================================================================================
@@ -355,8 +357,8 @@ protected:
/** Create a new tile operand referring to the specified tile object. */
static TileOperand create_tile_operand(ITile &tile);
- /** Get the reference to tile object from the tile operand. */
- static ITile &get_tile(const TileOperand &operand);
+ /** Get the reference to the tile object and the active area from the tile operand. */
+ static std::tuple<ITile &, TileArea> get_tile(const TileOperand &operand);
/** Create a new tensor operand from a tensor object. */
static TensorOperand create_tensor_operand(ITensor &tensor);
diff --git a/compute_kernel_writer/include/ckw/TileOperand.h b/compute_kernel_writer/include/ckw/TileOperand.h
index 873a9825f3..56dc5e7b2b 100644
--- a/compute_kernel_writer/include/ckw/TileOperand.h
+++ b/compute_kernel_writer/include/ckw/TileOperand.h
@@ -25,6 +25,8 @@
#ifndef CKW_INCLUDE_CKW_TILEOPERAND_H
#define CKW_INCLUDE_CKW_TILEOPERAND_H
+#include <cstdint>
+
namespace ckw
{
@@ -41,13 +43,55 @@ public:
friend class KernelWriter;
friend class TensorOperand;
+ /** Get a row vector of the current tile operand.
+ *
+ * @param[in] row The index of the row to be accessed in the current tile operand.
+ *
+ * @return A new tile operand referring to a row of the current tile operand.
+ */
+ TileOperand row(int32_t row) const;
+
+ /** Get a scalar element of the current tile operand.
+ *
+ * @param[in] row The index of the row to be accessed in the current tile operand.
+ * @param[in] col The index of the column to be accessed in the current tile operand.
+ *
+ * @return A new tile operand referring to a scalar element of the current tile operand.
+ */
+ TileOperand scalar(int32_t row, int32_t col) const;
+
private:
// These are hidden from the public API to avoid any misuse.
/** Initialize a new instance of @ref TileOperand class for the given tile. */
TileOperand(ITile &tile);
- ITile &_tile;
+ /** Initialize a new instance of @ref TileOperand class that is the sub-tile of the given tile. */
+ TileOperand(const TileOperand &operand, int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end);
+
+ /** Get a sub-tile of the current tile operand.
+ *
+ * The range of rows and columns is defined by pairs of start and end indices, inclusive lower and exclusive upper.
+ * In other words, any row and column indices satisfying the following conditions will be part of the sub-tile:
+ *
+ * row_start <= row_index < row_end
+ * col_start <= col_index < col_end
+ *
+ * @param[in] row_start The start index of the row range.
+ * @param[in] row_end The end index of the row range.
+ * @param[in] col_start The start index of the column range.
+ * @param[in] col_end The end index of the column range.
+ *
+ * @return A new tile operand referring to the same tile but with the new active area.
+ */
+ TileOperand tile(int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end) const;
+
+ ITile *_tile;
+
+ int32_t _row_start;
+ int32_t _row_end;
+ int32_t _col_start;
+ int32_t _col_end;
};
} // namespace ckw
diff --git a/compute_kernel_writer/src/KernelWriter.cpp b/compute_kernel_writer/src/KernelWriter.cpp
index fb0e62c8ce..0bea1200b7 100644
--- a/compute_kernel_writer/src/KernelWriter.cpp
+++ b/compute_kernel_writer/src/KernelWriter.cpp
@@ -27,10 +27,13 @@
#include "ckw/TileOperand.h"
#include "ckw/types/TargetArchitecture.h"
#include "ckw/types/TargetLanguage.h"
+#include "src/TileView.h"
#include "src/cl/CLKernelWriter.h"
#include "src/cl/CLTensorArgument.h"
#include "src/cl/CLTile.h"
+#include <tuple>
+
namespace ckw
{
@@ -90,9 +93,9 @@ TileOperand KernelWriter::create_tile_operand(ITile &tile)
return TileOperand(tile);
}
-ITile &KernelWriter::get_tile(const TileOperand &operand)
+std::tuple<ITile &, TileArea> KernelWriter::get_tile(const TileOperand &operand)
{
- return operand._tile;
+ return { *operand._tile, { operand._row_start, operand._row_end, operand._col_start, operand._col_end } };
}
TensorOperand KernelWriter::create_tensor_operand(ITensor &tensor)
diff --git a/compute_kernel_writer/src/TileOperand.cpp b/compute_kernel_writer/src/TileOperand.cpp
index 7d180feec8..3dfa2b8b2b 100644
--- a/compute_kernel_writer/src/TileOperand.cpp
+++ b/compute_kernel_writer/src/TileOperand.cpp
@@ -23,13 +23,49 @@
*/
#include "ckw/TileOperand.h"
+#include "ckw/Error.h"
+#include "src/ITile.h"
namespace ckw
{
TileOperand::TileOperand(ITile &tile)
- : _tile(tile)
+ : _tile(&tile), _row_start(0), _row_end(tile.info().height()), _col_start(0), _col_end(tile.info().width())
{
}
+/** Initialize a new instance of @ref TileOperand class that is the sub-tile of the given tile.
+ *
+ * The new operand shares the tile object of @p operand. The four indices are stored as given and are
+ * checked against the dimensions of that underlying tile, not against the active area of @p operand.
+ * Ranges are inclusive at the start and exclusive at the end.
+ *
+ * @param[in] operand   The operand whose tile object is shared.
+ * @param[in] row_start The start index of the row range.
+ * @param[in] row_end   The end index of the row range (exclusive).
+ * @param[in] col_start The start index of the column range.
+ * @param[in] col_end   The end index of the column range (exclusive).
+ */
+TileOperand::TileOperand(const TileOperand &operand, int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end)
+ : _tile(operand._tile), _row_start(row_start), _row_end(row_end), _col_start(col_start), _col_end(col_end)
+{
+ // Each range must be non-empty and lie within the underlying tile.
+ CKW_ASSERT(row_start >= 0 && row_start < _tile->info().height());
+ CKW_ASSERT(row_end > row_start && row_end <= _tile->info().height());
+ CKW_ASSERT(col_start >= 0 && col_start < _tile->info().width());
+ CKW_ASSERT(col_end > col_start && col_end <= _tile->info().width());
+}
+
+/** Get a sub-tile of the current tile operand.
+ *
+ * All indices are relative to the active area of the current operand (index 0 is its first row/column).
+ * They are validated against that active area and then offset by its start row/column, so the returned
+ * operand stores indices of the underlying tile.
+ *
+ * @param[in] row_start The start index of the row range.
+ * @param[in] row_end   The end index of the row range (exclusive).
+ * @param[in] col_start The start index of the column range.
+ * @param[in] col_end   The end index of the column range (exclusive).
+ *
+ * @return A new tile operand referring to the same tile but with the new active area.
+ */
+TileOperand TileOperand::tile(int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end) const
+{
+ // The requested ranges must be non-empty and stay inside the current active area.
+ CKW_ASSERT(row_start >= 0 && _row_start + row_start < _row_end);
+ CKW_ASSERT(row_end > row_start && _row_start + row_end <= _row_end);
+ CKW_ASSERT(col_start >= 0 && _col_start + col_start < _col_end);
+ CKW_ASSERT(col_end > col_start && _col_start + col_end <= _col_end);
+
+ // Translate the relative indices into indices of the underlying tile.
+ return TileOperand(*this, _row_start + row_start, _row_start + row_end, _col_start + col_start, _col_start + col_end);
+}
+
+/** Get a row vector of the current tile operand.
+ *
+ * @param[in] row The index of the row, relative to the active area of the current tile operand.
+ *
+ * @return A new tile operand covering every column of the active area on that single row.
+ */
+TileOperand TileOperand::row(int32_t row) const
+{
+    CKW_ASSERT(row >= 0 && _row_start + row < _row_end);
+
+    // tile() takes indices relative to the active area and applies the _row_start/_col_start
+    // offsets itself. Passing already-offset indices would apply the offsets twice, which
+    // selects the wrong area (or trips the asserts in tile()) whenever this operand is a sub-tile.
+    return tile(row, row + 1, 0, _col_end - _col_start);
+}
+
+/** Get a scalar element of the current tile operand.
+ *
+ * @param[in] row The index of the row, relative to the active area of the current tile operand.
+ * @param[in] col The index of the column, relative to the active area of the current tile operand.
+ *
+ * @return A new tile operand whose active area is the single element at (row, col).
+ */
+TileOperand TileOperand::scalar(int32_t row, int32_t col) const
+{
+    CKW_ASSERT(row >= 0 && _row_start + row < _row_end);
+    CKW_ASSERT(col >= 0 && _col_start + col < _col_end);
+
+    // tile() takes indices relative to the active area and applies the _row_start/_col_start
+    // offsets itself. Passing already-offset indices would apply the offsets twice, which
+    // selects the wrong element (or trips the asserts in tile()) whenever this operand is a sub-tile.
+    return tile(row, row + 1, col, col + 1);
+}
+
} // namespace ckw
diff --git a/compute_kernel_writer/src/TileView.cpp b/compute_kernel_writer/src/TileView.cpp
new file mode 100644
index 0000000000..ea803f92f4
--- /dev/null
+++ b/compute_kernel_writer/src/TileView.cpp
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/TileView.h"
+
+#include <cstdint>
+
+namespace ckw
+{
+
+/** Record the row range and the column range of the area exactly as given; no validation is done here. */
+TileArea::TileArea(int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end)
+    : _row_start{ row_start }, _row_end{ row_end }, _col_start{ col_start }, _col_end{ col_end }
+{
+}
+
+/** Index of the first row of the area. */
+int32_t TileArea::row_start() const
+{
+    return this->_row_start;
+}
+
+/** Index one past the last row of the area. */
+int32_t TileArea::row_end() const
+{
+    return this->_row_end;
+}
+
+/** Index of the first column of the area. */
+int32_t TileArea::col_start() const
+{
+    return this->_col_start;
+}
+
+/** Index one past the last column of the area. */
+int32_t TileArea::col_end() const
+{
+    return this->_col_end;
+}
+
+} // namespace ckw
diff --git a/compute_kernel_writer/src/TileView.h b/compute_kernel_writer/src/TileView.h
new file mode 100644
index 0000000000..e0d034fa8d
--- /dev/null
+++ b/compute_kernel_writer/src/TileView.h
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_SRC_TILEVIEW_H
+#define CKW_SRC_TILEVIEW_H
+
+#include "ckw/Error.h"
+#include "ckw/types/DataType.h"
+#include "src/ITile.h"
+
+#include <cstdint>
+
+namespace ckw
+{
+
+/** A rectangular active area of a tile.
+ *
+ * This is a plain value type holding two index ranges (rows and columns).
+ */
+class TileArea
+{
+public:
+ /** Create a new rectangular active area of a tile.
+ *
+ * The range of rows and columns is defined by pairs of start and end indices, inclusive lower and exclusive upper.
+ * In other words, any row and column indices satisfying the following conditions will be part of the active area:
+ *
+ * row_start <= row_index < row_end
+ * col_start <= col_index < col_end
+ *
+ * The constructor stores the indices as given and does not validate them.
+ *
+ * @param[in] row_start The start index of the row range (inclusive).
+ * @param[in] row_end The end index of the row range (exclusive).
+ * @param[in] col_start The start index of the column range (inclusive).
+ * @param[in] col_end The end index of the column range (exclusive).
+ */
+ TileArea(int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end);
+
+ /** Get the start row (inclusive) index. */
+ int32_t row_start() const;
+
+ /** Get the end row (exclusive) index. */
+ int32_t row_end() const;
+
+ /** Get the start column (inclusive) index. */
+ int32_t col_start() const;
+
+ /** Get the end column (exclusive) index. */
+ int32_t col_end() const;
+
+private:
+ // Row range [_row_start, _row_end) and column range [_col_start, _col_end).
+ int32_t _row_start;
+ int32_t _row_end;
+ int32_t _col_start;
+ int32_t _col_end;
+};
+
+/** A rectangular view of a tile.
+ *
+ * The view keeps a pointer to the tile object together with the active area it exposes.
+ * It does not own the tile, so the tile must outlive the view.
+ *
+ * @tparam T The tile type. It must provide info(), name(), vector(row, col_start, width)
+ *           and scalar(row, col), as used by the members below.
+ */
+template <typename T>
+class TileView
+{
+public:
+    /** Create a tile view that refers to the whole tile.
+     *
+     * @param[in] tile The tile object.
+     */
+    TileView(const T &tile)
+        : _tile(&tile), _area(0, tile.info().height(), 0, tile.info().width())
+    {
+    }
+
+    /** Create a new rectangular view of the given tile.
+     *
+     * @param[in] tile The tile object.
+     * @param[in] area The rectangular active area.
+     */
+    TileView(const T &tile, const TileArea &area)
+        : _tile(&tile), _area(area)
+    {
+    }
+
+    /** Get the tile object.
+     *
+     * The caller must guarantee that the tile view refers to the whole tile.
+     */
+    const T &full_tile() const
+    {
+        CKW_ASSERT(is_full_tile());
+
+        return *_tile;
+    }
+
+    /** Get the data type of the tile. */
+    DataType data_type() const
+    {
+        return _tile->info().data_type();
+    }
+
+    /** Get the start row index. */
+    int32_t row_start() const
+    {
+        return _area.row_start();
+    }
+
+    /** Get the end row (exclusive) index. */
+    int32_t row_end() const
+    {
+        return _area.row_end();
+    }
+
+    /** Get the start column index. */
+    int32_t col_start() const
+    {
+        return _area.col_start();
+    }
+
+    /** Get the end column (exclusive) index. */
+    int32_t col_end() const
+    {
+        return _area.col_end();
+    }
+
+    /** Get the height of the tile view. */
+    int32_t height() const
+    {
+        return _area.row_end() - _area.row_start();
+    }
+
+    /** Get the width of the tile view. */
+    int32_t width() const
+    {
+        return _area.col_end() - _area.col_start();
+    }
+
+    /** See @ref IVectorAccess::vector.
+     *
+     * @param[in] row The row index relative to the view. Indices outside the view are clamped to
+     *                its first/last row, so a single-row view is repeated for every requested row.
+     *
+     * @return The vector covering the columns of the view on the selected row.
+     */
+    TileVariable vector(int32_t row) const
+    {
+        // Callers iterate over the destination height and rely on a shorter source being
+        // broadcast in the y dimension. Without clamping here, a view that covers only part
+        // of a taller tile would address rows of the tile that lie outside its active area.
+        const int32_t last_row   = height() - 1;
+        const int32_t capped_row = row > last_row ? last_row : row;
+        const int32_t view_row   = capped_row < 0 ? 0 : capped_row;
+
+        return _tile->vector(row_start() + view_row, col_start(), width());
+    }
+
+    /** See @ref IScalarAccess::scalar.
+     *
+     * @param[in] row The row index relative to the view.
+     * @param[in] col The column index relative to the view.
+     *
+     * NOTE(review): the indices are forwarded without being limited to the view; confirm
+     * whether scalar access should also be restricted to the active area.
+     */
+    TileVariable scalar(int32_t row, int32_t col) const
+    {
+        return _tile->scalar(row_start() + row, col_start() + col);
+    }
+
+    /** Get the name of the tile. */
+    const std::string &name() const
+    {
+        return _tile->name();
+    }
+
+    /** Get whether the tile view is a scalar element. */
+    bool is_scalar() const
+    {
+        return height() == 1 && width() == 1;
+    }
+
+    /** Get whether the tile view refers to the whole tile. */
+    bool is_full_tile() const
+    {
+        return row_start() == 0 && row_end() == _tile->info().height() && col_start() == 0 && col_end() == _tile->info().width();
+    }
+
+private:
+    const T *_tile; // Non-owning pointer to the viewed tile.
+    TileArea _area; // Active area, expressed in indices of the viewed tile.
+};
+
+} // namespace ckw
+
+#endif // CKW_SRC_TILEVIEW_H
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index 4074da7912..2db9c139b7 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -32,6 +32,7 @@
#include "ckw/types/MemoryOperation.h"
#include "ckw/types/TargetLanguage.h"
#include "src/ITensorComponent.h"
+#include "src/TileView.h"
#include "src/cl/CLHelpers.h"
#include "src/cl/CLTensorArgument.h"
#include "src/cl/CLTile.h"
@@ -42,6 +43,7 @@
#include <algorithm>
#include <cstdint>
+#include <tuple>
#include <vector>
namespace ckw
@@ -112,39 +114,39 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name)
void CLKernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &src_tile = to_cl_tile(src);
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto src_view = to_cl_tile_view(src);
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto src_w = src_tile.info().width();
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto src_w = src_view.width();
- const auto data_type_str = cl_get_variable_datatype_as_string(dst_tile.info().data_type(), dst_w);
+ const auto data_type_str = cl_get_variable_datatype_as_string(dst_view.data_type(), dst_w);
const auto broadcast_src_x = dst_w != 1 && src_w == 1;
const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : "";
- CKW_ASSERT_MSG(src_tile.info().data_type() == dst_tile.info().data_type(), "Source and destination type must match.");
- CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
for(int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_tile.vector(y).str, " = ", src_prefix, src_tile.vector(y).str, ";\n");
+ append_code(dst_view.vector(y).str, " = ", src_prefix, src_view.vector(y).str, ";\n");
}
}
void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &src_tile = to_cl_tile(src);
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto src_view = to_cl_tile_view(src);
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto src_w = src_tile.info().width();
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto src_w = src_view.width();
- const auto dst_type = dst_tile.info().data_type();
+ const auto dst_type = dst_view.data_type();
const auto convert_type_str = cl_get_variable_datatype_as_string(dst_type, src_w);
const auto dst_type_str = cl_get_variable_datatype_as_string(dst_type, dst_w);
@@ -155,27 +157,27 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con
const auto broadcast_x = dst_w != 1 && src_w == 1;
const std::string prefix = broadcast_x ? "(" + dst_type_str + ")" : "";
- CKW_ASSERT_MSG(src_tile.info().data_type() != dst_tile.info().data_type(), "Source and destination type must be different.");
- CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.data_type() != dst_view.data_type(), "Source and destination type must be different.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
for(int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_tile.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", src_tile.vector(y).str, ");\n");
+ append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", src_view.vector(y).str, ");\n");
}
}
void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &src_tile = to_cl_tile(src);
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto src_view = to_cl_tile_view(src);
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto src_w = src_tile.info().width();
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto src_w = src_view.width();
- const auto data_type_str = cl_get_variable_datatype_as_string(dst_tile.info().data_type(), dst_w);
+ const auto data_type_str = cl_get_variable_datatype_as_string(dst_view.data_type(), dst_w);
const auto broadcast_src_x = dst_w != 1 && src_w == 1;
const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : "";
@@ -186,35 +188,34 @@ void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOper
const auto op_prefix = op_is_func ? op_name + "(" : op_name;
const auto op_suffix = op_is_func ? ")" : "";
- CKW_ASSERT_MSG(src_tile.info().data_type() == dst_tile.info().data_type(), "Source and destination type must match.");
- CKW_ASSERT_MSG(!is_data_type_float(src_tile.info().data_type()), "Logical and bitwise not only work with integer.");
- CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match.");
+ CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
// Broadcasting on y dimension is automatic (see CLTile::vector).
for(int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_tile.vector(y).str, " = ", src_prefix, op_prefix, src_tile.vector(y).str, op_suffix, ";\n");
+ append_code(dst_view.vector(y).str, " = ", src_prefix, op_prefix, src_view.vector(y).str, op_suffix, ";\n");
}
}
void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &lhs_tile = to_cl_tile(first);
- const auto &rhs_tile = to_cl_tile(second);
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto lhs_view = to_cl_tile_view(first);
+ const auto rhs_view = to_cl_tile_view(second);
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto lhs_w = lhs_tile.info().width();
- const auto rhs_w = rhs_tile.info().width();
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto lhs_w = lhs_view.width();
+ const auto rhs_w = rhs_view.width();
- const auto data_type = lhs_tile.info().data_type();
+ const auto data_type = lhs_view.data_type();
- CKW_ASSERT_MSG(lhs_tile.info().data_type() == rhs_tile.info().data_type(), "LHS and RHS type must match.");
+ CKW_ASSERT_MSG(lhs_view.data_type() == rhs_view.data_type(), "LHS and RHS type must match.");
- CKW_ASSERT_MSG(lhs_tile.info().height() == dst_h || lhs_tile.info().height() == 1, "LHS tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(rhs_tile.info().height() == dst_h || rhs_tile.info().height() == 1, "RHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1, "LHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1, "RHS tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, "LHS tile width must match destination or LHS is broadcasting in x dimension.");
CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, "RHS tile width must match destination or RHS is broadcasting in x dimension.");
@@ -230,10 +231,10 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp
for(int32_t k = 0; k < lhs_w; ++k)
{
append_code(
- dst_tile.scalar(x, y).str, " = fma(",
- lhs_tile.scalar(k, y).str, ", ",
- rhs_tile.scalar(k, x).str, ", ",
- dst_tile.scalar(x, y).str, ");\n");
+ dst_view.scalar(x, y).str, " = fma(",
+ lhs_view.scalar(k, y).str, ", ",
+ rhs_view.scalar(k, x).str, ", ",
+ dst_view.scalar(x, y).str, ");\n");
}
}
}
@@ -259,25 +260,25 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp
// Broadcasting on y dimension is automatic (see CLTile::vector).
for(int32_t y = 0; y < dst_h; ++y)
{
- append_code(dst_tile.vector(y).str, op_prefix, lhs_prefix, lhs_tile.vector(y).str, op_separator, rhs_prefix, rhs_tile.vector(y).str, op_suffix);
+ append_code(dst_view.vector(y).str, op_prefix, lhs_prefix, lhs_view.vector(y).str, op_separator, rhs_prefix, rhs_view.vector(y).str, op_suffix);
}
}
}
void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third)
{
- const auto &dst_tile = to_cl_tile(dst);
- const auto &first_tile = to_cl_tile(first);
- const auto &second_tile = to_cl_tile(second);
- const auto &third_tile = to_cl_tile(third);
-
- const auto dst_w = dst_tile.info().width();
- const auto dst_h = dst_tile.info().height();
- const auto first_w = first_tile.info().width();
- const auto second_w = second_tile.info().width();
- const auto third_w = third_tile.info().width();
-
- const auto data_type = dst_tile.info().data_type();
+ const auto dst_view = to_cl_tile_view(dst);
+ const auto first_view = to_cl_tile_view(first);
+ const auto second_view = to_cl_tile_view(second);
+ const auto third_view = to_cl_tile_view(third);
+
+ const auto dst_w = dst_view.width();
+ const auto dst_h = dst_view.height();
+ const auto first_w = first_view.width();
+ const auto second_w = second_view.width();
+ const auto third_w = third_view.width();
+
+ const auto data_type = dst_view.data_type();
const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w);
const auto op_info = cl_get_ternary_op(op);
@@ -293,12 +294,12 @@ void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const Tile
const std::string third_prefix = broadcast_third_x ? "(" + data_type_str + ")" : "";
CKW_ASSERT_MSG(op_is_func, "The only supported ternary operator is function.");
- CKW_ASSERT_MSG(second_tile.info().data_type() == dst_tile.info().data_type(), "2nd source and destination type must match.");
- CKW_ASSERT_MSG(third_tile.info().data_type() == dst_tile.info().data_type(), "3rd source and destination type must match.");
+ CKW_ASSERT_MSG(second_view.data_type() == dst_view.data_type(), "2nd source and destination type must match.");
+ CKW_ASSERT_MSG(third_view.data_type() == dst_view.data_type(), "3rd source and destination type must match.");
- CKW_ASSERT_MSG(first_tile.info().height() == dst_h || first_tile.info().height() == 1, "1st tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(second_tile.info().height() == dst_h || second_tile.info().height() == 1, "2nd tile height must match or source is broadcasting in y dimension.");
- CKW_ASSERT_MSG(third_tile.info().height() == dst_h || third_tile.info().height() == 1, "3rd tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(first_view.height() == dst_h || first_view.height() == 1, "1st tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(second_view.height() == dst_h || second_view.height() == 1, "2nd tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(third_view.height() == dst_h || third_view.height() == 1, "3rd tile height must match or source is broadcasting in y dimension.");
CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, "1st tile width must match or source is broadcasting in x dimension.");
CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, "2nd tile width must match or source is broadcasting in x dimension.");
@@ -308,30 +309,30 @@ void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const Tile
for(int32_t y = 0; y < dst_h; ++y)
{
append_code(
- dst_tile.vector(y).str, " = ", op_name, "(",
- first_prefix, first_tile.vector(y).str, ", ",
- second_prefix, second_tile.vector(y).str, ", ",
- third_prefix, third_tile.vector(y).str, ");\n");
+ dst_view.vector(y).str, " = ", op_name, "(",
+ first_prefix, first_view.vector(y).str, ", ",
+ second_prefix, second_view.vector(y).str, ", ",
+ third_prefix, third_view.vector(y).str, ");\n");
}
}
void CLKernelWriter::op_if_generic(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if)
{
- const auto &lhs_tile = to_cl_tile(lhs);
- const auto &rhs_tile = to_cl_tile(rhs);
+ const auto lhs_view = to_cl_tile_view(lhs);
+ const auto rhs_view = to_cl_tile_view(rhs);
- const auto op_name = std::get<1>(cl_get_binary_op(op, lhs_tile.info().data_type()));
+ const auto op_name = std::get<1>(cl_get_binary_op(op, lhs_view.data_type()));
CKW_ASSERT(op == BinaryOp::Less || op == BinaryOp::LessEqual || op == BinaryOp::Equal || op == BinaryOp::GreaterEqual || op == BinaryOp::Greater);
- CKW_ASSERT(lhs_tile.is_scalar());
- CKW_ASSERT(rhs_tile.is_scalar());
+ CKW_ASSERT(lhs_view.is_scalar());
+ CKW_ASSERT(rhs_view.is_scalar());
if(is_else_if)
{
append_code("else ");
}
- append_code("if (", lhs_tile.scalar(0, 0).str, " ", op_name, " ", rhs_tile.scalar(0, 0).str, ")\n{\n");
+ append_code("if (", lhs_view.scalar(0, 0).str, " ", op_name, " ", rhs_view.scalar(0, 0).str, ")\n{\n");
write_body(body);
append_code("}\n");
}
@@ -358,25 +359,25 @@ void CLKernelWriter::op_for_loop(
const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value,
const std::function<void()> &body)
{
- const auto &var_tile = to_cl_tile(var);
- const auto &cond_value_tile = to_cl_tile(cond_value);
- const auto &update_var_tile = to_cl_tile(update_var);
- const auto &update_value_tile = to_cl_tile(update_value);
+ const auto var_view = to_cl_tile_view(var);
+ const auto cond_value_view = to_cl_tile_view(cond_value);
+ const auto update_var_view = to_cl_tile_view(update_var);
+ const auto update_value_view = to_cl_tile_view(update_value);
- CKW_ASSERT(var_tile.is_scalar());
- CKW_ASSERT(cond_value_tile.is_scalar());
- CKW_ASSERT(update_var_tile.is_scalar());
- CKW_ASSERT(update_value_tile.is_scalar());
+ CKW_ASSERT(var_view.is_scalar());
+ CKW_ASSERT(cond_value_view.is_scalar());
+ CKW_ASSERT(update_var_view.is_scalar());
+ CKW_ASSERT(update_value_view.is_scalar());
- CKW_ASSERT(var_tile.info().data_type() == cond_value_tile.info().data_type());
- CKW_ASSERT(update_var_tile.info().data_type() == update_value_tile.info().data_type());
+ CKW_ASSERT(var_view.data_type() == cond_value_view.data_type());
+ CKW_ASSERT(update_var_view.data_type() == update_value_view.data_type());
- const auto cond_op_name = std::get<1>(cl_get_binary_op(cond_op, var_tile.info().data_type()));
+ const auto cond_op_name = std::get<1>(cl_get_binary_op(cond_op, var_view.data_type()));
CKW_ASSERT(cond_op == BinaryOp::Less || cond_op == BinaryOp::LessEqual || cond_op == BinaryOp::Equal || cond_op == BinaryOp::GreaterEqual || cond_op == BinaryOp::Greater);
append_code(
- "for (; ", var_tile.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_tile.scalar(0, 0).str, "; ",
- update_var_tile.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", update_value_tile.scalar(0, 0).str, ")\n{\n");
+ "for (; ", var_view.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_view.scalar(0, 0).str, "; ",
+ update_var_view.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", update_value_view.scalar(0, 0).str, ")\n{\n");
write_body(body);
append_code("}\n");
}
@@ -388,14 +389,14 @@ void CLKernelWriter::op_return()
void CLKernelWriter::op_get_global_id(const TileOperand &dst, int32_t dim)
{
- const auto &tile = to_cl_tile(dst);
+ const auto tile_view = to_cl_tile_view(dst);
- CKW_ASSERT(tile.is_scalar());
- CKW_ASSERT(tile.info().data_type() == DataType::Int32 || tile.info().data_type() == DataType::Uint32);
+ CKW_ASSERT(tile_view.is_scalar());
+ CKW_ASSERT(tile_view.data_type() == DataType::Int32 || tile_view.data_type() == DataType::Uint32);
CKW_ASSERT(dim >= 0 && dim <= 2);
- append_code(tile.scalar(0, 0).str, " = get_global_id(", std::to_string(dim), ");\n");
+ append_code(tile_view.scalar(0, 0).str, " = get_global_id(", std::to_string(dim), ");\n");
}
void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileOperand> &operands)
@@ -405,13 +406,12 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
for(auto &op : operands)
{
- const auto &tile = to_cl_tile(op);
- const auto &info = tile.info();
+ const auto tile_view = to_cl_tile_view(op);
- const auto &name = tile.name();
- const auto width = info.width();
- const auto height = info.height();
- const auto data_type = info.data_type();
+ const auto name = tile_view.name();
+ const auto width = tile_view.width();
+ const auto height = tile_view.height();
+ const auto data_type = tile_view.data_type();
// Construct the format specifier to print out one row of the tile.
std::string row_format("%");
@@ -479,7 +479,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO
// Construct the variable arguments for the printf statement.
for(int32_t row = 0; row < height; ++row)
{
- args_code += ", " + tile.vector(row).str;
+ args_code += ", " + tile_view.vector(row).str;
}
}
@@ -566,9 +566,11 @@ void CLKernelWriter::op_write_raw_code(const std::string &raw_code)
append_code(raw_code);
}
-const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand) const
+TileView<CLTile> CLKernelWriter::to_cl_tile_view(const TileOperand &operand) const
{
- const auto &tile = get_tile(operand);
+ const auto tile_and_area = get_tile(operand);
+ ITile &tile = std::get<0>(tile_and_area);
+ const TileArea area = std::get<1>(tile_and_area);
#ifdef COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
// Check if the tile is a CLTile created by this kernel writer.
@@ -620,7 +622,7 @@ const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand) const
}
#endif // COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
- return static_cast<const CLTile &>(tile);
+ return { static_cast<CLTile &>(tile), area };
}
void CLKernelWriter::op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
@@ -636,10 +638,10 @@ void CLKernelWriter::op_load_dilated(const TileOperand &tile_op, const TensorOpe
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const TileOperand &dilation_x, const TileOperand &dilation_y)
{
- const auto &dil_x_tile = to_cl_tile(dilation_x);
- const auto &dil_y_tile = to_cl_tile(dilation_y);
+ const auto dil_x_view = to_cl_tile_view(dilation_x);
+ const auto dil_y_view = to_cl_tile_view(dilation_y);
- op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_tile, dil_y_tile, false /* indirect buffer */);
+ op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */);
}
void CLKernelWriter::op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
@@ -655,14 +657,14 @@ void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op, const Tile
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const TileOperand &dilation_x, const TileOperand &dilation_y)
{
- const auto &dil_x_tile = to_cl_tile(dilation_x);
- const auto &dil_y_tile = to_cl_tile(dilation_y);
+ const auto dil_x_view = to_cl_tile_view(dilation_x);
+ const auto dil_y_view = to_cl_tile_view(dilation_y);
- op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_tile, dil_y_tile, false /* indirect buffer */);
+ op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */);
}
void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
{
const CLTile dilation_x({ { "1" } }, DataType::Int32);
const CLTile dilation_y({ { "1" } }, DataType::Int32);
@@ -671,8 +673,8 @@ void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, const TensorOp
}
void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const CLTile &dilation_x, const CLTile &dilation_y, bool indirect_buffer)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
+ const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer)
{
CKW_UNUSED(dilation_x);
CKW_ASSERT(dilation_x.is_scalar());
@@ -681,7 +683,7 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o
if(indirect_buffer)
{
- CKW_ASSERT(dilation_y.scalar(0,0).str == "((int)(1))" && dilation_x.scalar(0,0).str == "((int)(1))");
+ CKW_ASSERT(dilation_y.scalar(0, 0).str == "((int)(1))" && dilation_x.scalar(0, 0).str == "((int)(1))");
}
ITensor &tensor = get_tensor(tensor_op);
@@ -700,11 +702,12 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o
CKW_THROW_MSG("Unsupported tensor storage");
}
- const auto &tile = to_cl_tile(tile_op);
- const auto &x_tile = to_cl_tile(x);
- const auto &y_tile = to_cl_tile(y);
- const auto &z_tile = to_cl_tile(z);
- const auto &batch_tile = to_cl_tile(batch);
+ // Load/store op doesn't support sub-tile access.
+ const auto tile = to_cl_tile_view(tile_op).full_tile();
+ const auto x_tile = to_cl_tile_view(x).full_tile();
+ const auto y_tile = to_cl_tile_view(y).full_tile();
+ const auto z_tile = to_cl_tile_view(z).full_tile();
+ const auto batch_tile = to_cl_tile_view(batch).full_tile();
CKW_ASSERT(x_tile.is_scalar());
CKW_ASSERT(z_tile.is_scalar());
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h
index 1e2e5dc910..d7cf24d5e6 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.h
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.h
@@ -26,6 +26,7 @@
#define CKW_SRC_CL_CLKERNELWRITER_H
#include "ckw/KernelWriter.h"
+#include "src/TileView.h"
#include <memory>
#include <set>
@@ -150,14 +151,14 @@ public:
const TileOperand &dilation_x, const TileOperand &dilation_y) override;
void op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
protected:
- /** Return @ref CLTile object from the @ref TileOperand object.
+ /** Return a tile view containing a reference to @ref CLTile object and the active area.
*
* This function performs appropriate check before doing type casting.
*/
- const CLTile &to_cl_tile(const TileOperand &operand) const;
+ TileView<CLTile> to_cl_tile_view(const TileOperand &operand) const;
/** Append the specified code to the kernel body source code. */
template <typename T, typename... TArgs>
@@ -181,8 +182,8 @@ protected:
private:
/** Helper method to consolidate all load/store logic in this class */
void op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const CLTile &dilation_x, const CLTile &dilation_y, bool indirect_buffer);
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
+ const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer);
/** This function is the generic function to write both `if` and `else if` blocks.
*
diff --git a/compute_kernel_writer/src/cl/CLTile.cpp b/compute_kernel_writer/src/cl/CLTile.cpp
index 556db0f47b..0cce69a9e1 100644
--- a/compute_kernel_writer/src/cl/CLTile.cpp
+++ b/compute_kernel_writer/src/cl/CLTile.cpp
@@ -125,6 +125,9 @@ TileVariable CLTile::vector(int32_t row) const
TileVariable CLTile::vector(int32_t row, int32_t col_start, int32_t width) const
{
+ CKW_ASSERT(col_start >= 0 && col_start < _info.width());
+ CKW_ASSERT(col_start + width <= _info.width());
+
// Validate the new vector length
cl_validate_vector_length(width);
@@ -154,7 +157,7 @@ TileVariable CLTile::vector(int32_t row, int32_t col_start, int32_t width) const
{
t.str = create_var_name(row);
- if(_info.width() != 1)
+ if(_info.width() != 1 && _info.width() != width)
{
t.str += ".s";
for(int i = 0; i < width; ++i)
diff --git a/compute_kernel_writer/validation/Validation.cpp b/compute_kernel_writer/validation/Validation.cpp
index 7031fe80a9..4fbd1eacda 100644
--- a/compute_kernel_writer/validation/Validation.cpp
+++ b/compute_kernel_writer/validation/Validation.cpp
@@ -37,6 +37,7 @@
#include "validation/tests/CLKernelWriterOpLoadStoreTest.h"
#include "validation/tests/CLKernelWriterPrintTest.h"
#include "validation/tests/CLKernelWriterReturnTest.h"
+#include "validation/tests/CLKernelWriterSubTileTest.h"
#include "validation/tests/CLKernelWriterTernaryOpTest.h"
#include "validation/tests/CLKernelWriterUnaryExpressionTest.h"
#include "validation/tests/CLTensorArgumentTest.h"
@@ -102,6 +103,7 @@ int32_t main()
const auto test35 = std::make_unique<CLKernelWriterGetGlobalIdTest>();
const auto test36 = std::make_unique<CLKernelWriterPrintTest>();
const auto test37 = std::make_unique<CLKernelWriterOpLoadIndirectTest>();
+ const auto test38 = std::make_unique<CLKernelWriterSubTileTest>();
tests.push_back(test3.get());
tests.push_back(test4.get());
@@ -140,6 +142,7 @@ int32_t main()
tests.push_back(test35.get());
tests.push_back(test36.get());
tests.push_back(test37.get());
+ tests.push_back(test38.get());
#endif /* COMPUTE_KERNEL_WRITER_OPENCL_ENABLED */
bool all_test_passed = true;
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterSubTileTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterSubTileTest.h
new file mode 100644
index 0000000000..ea360b289e
--- /dev/null
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterSubTileTest.h
@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_VALIDATION_SRC_TESTS_CLKERNELWRITERSUBTILETEST_H
+#define CKW_VALIDATION_SRC_TESTS_CLKERNELWRITERSUBTILETEST_H
+
+#include "ckw/TileInfo.h"
+#include "ckw/types/DataType.h"
+#include "ckw/types/Operators.h"
+#include "src/cl/CLKernelWriter.h"
+#include "validation/tests/common/Common.h"
+#include "validation/tests/common/KernelWriterInterceptor.h"
+
+#include <cstdint>
+#include <vector>
+
+namespace ckw
+{
+
+/** Check the code written by CLKernelWriter when operands are full tiles, row vectors or scalars of a tile. */
+class CLKernelWriterSubTileTest : public ITest
+{
+public:
+    CLKernelWriterSubTileTest()
+    {
+        // These are the definitions of the tiles involved in the writing actions.
+        //
+        // Structure:
+        //   * List of tiles:
+        //     - Tile full height.
+        //     - Tile full width.
+        //     - Tile view access type (full tile, vector, scalar, scalar of vector).
+        //     - Tile view start row.
+        //     - Tile view start column.
+        //     - The expected operand name ({tile_name} is replaced by the full tile name).
+
+        // Vector access.
+        _tests.push_back(
+            { { { 1, 4, AccessType::Vector, 0, 0, "{tile_name}" }, //
+                { 4, 4, AccessType::Vector, 2, 0, "{tile_name}__2" },
+                { 1, 4, AccessType::Full, 0, 0, "{tile_name}" },
+                { 4, 4, AccessType::Vector, 3, 0, "{tile_name}__3" } } });
+
+        // Scalar access.
+        _tests.push_back(
+            { { { 1, 1, AccessType::Full, 0, 0, "{tile_name}" }, //
+                { 4, 8, AccessType::Scalar, 2, 4, "{tile_name}__2.s4" },
+                { 1, 16, AccessType::ScalarOfVector, 0, 10, "{tile_name}.sA" },
+                { 1, 1, AccessType::Scalar, 0, 0, "{tile_name}" } } });
+
+        // These are the definitions of the writing actions.
+        //
+        // Structure:
+        //   * Writing function.
+        //   * Whether this function only works with scalar values.
+        //   * Expected code format.
+
+        _actions.push_back(
+            { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
+              {
+                  writer.op_assign(args.at(0), args.at(1));
+              },
+              false,
+              "{op0} = {op1};\n" });
+
+        _actions.push_back(
+            { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
+              {
+                  writer.op_unary(args.at(0), UnaryOp::Sqrt, args.at(1));
+              },
+              false,
+              "{op0} = sqrt({op1});\n" });
+
+        _actions.push_back(
+            { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
+              {
+                  writer.op_binary(args.at(0), BinaryOp::Add, args.at(1), args.at(2));
+              },
+              false,
+              "{op0} = {op1} + {op2};\n" });
+
+        _actions.push_back(
+            { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
+              {
+                  writer.op_ternary(args.at(0), TernaryOp::Clamp, args.at(1), args.at(2), args.at(3));
+              },
+              false,
+              "{op0} = clamp({op1}, {op2}, {op3});\n" });
+
+        _actions.push_back(
+            { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
+              {
+                  writer.op_if(args.at(0), BinaryOp::Greater, args.at(1), [] {});
+              },
+              true,
+              "if ({op0} > {op1})\n{\n}\n" });
+    }
+
+    /** Write every action with every operand set and validate the captured code against the expected string. */
+    bool run() override
+    {
+        bool all_tests_passed = true;
+        int32_t test_idx = 0;
+
+        KernelWriterInterceptor<CLKernelWriter> writer;
+
+        for(size_t test_no = 0; test_no < _tests.size(); ++test_no)
+        {
+            const TestInfo &test = _tests.at(test_no);
+
+            // Declare all the tiles and record the full name of each tile operand.
+            std::vector<TileOperand> tiles;
+            std::vector<std::string> expected_tiles_name;
+
+            for(size_t operand_no = 0; operand_no < test.operands.size(); ++operand_no)
+            {
+                const TestOperand &operand = test.operands.at(operand_no);
+                const std::string name = "test" + std::to_string(test_no) + "_op" + std::to_string(operand_no);
+
+                const TileOperand full_tile = writer.declare_tile(name, TileInfo(DataType::Fp32, operand.height, operand.width));
+
+                switch(operand.access_type)
+                {
+                    case AccessType::Full:
+                        tiles.emplace_back(full_tile);
+                        break;
+                    case AccessType::Vector:
+                        tiles.emplace_back(full_tile.row(operand.start_row));
+                        break;
+                    case AccessType::Scalar:
+                        tiles.emplace_back(full_tile.scalar(operand.start_row, operand.start_col));
+                        break;
+                    case AccessType::ScalarOfVector:
+                        tiles.emplace_back(full_tile.row(operand.start_row).scalar(0, operand.start_col));
+                        break;
+                    default:
+                        CKW_THROW_MSG("Unsupported access type!");
+                }
+
+                expected_tiles_name.push_back("G0__" + name);
+            }
+
+            // Try each writing action using the newly declared tiles.
+            for(const TestAction &action : _actions)
+            {
+                // Scalar-only actions are skipped unless the first operand is a scalar view or a 1x1 tile.
+                if(action.scalar_only && //
+                   (test.operands.at(0).access_type != AccessType::Scalar && //
+                    (test.operands.at(0).height != 1 || test.operands.at(0).width != 1)))
+                {
+                    continue;
+                }
+
+                writer.start_capture_code();
+                action.write(writer, tiles);
+
+                // The expected code is constructed from the format strings.
+                std::string expected_code = action.expected_code;
+
+                for(size_t operand_no = 0; operand_no < test.operands.size(); ++operand_no)
+                {
+                    const TestOperand &operand = test.operands.at(operand_no);
+                    const std::string op_name = search_and_replace(operand.name, "{tile_name}", expected_tiles_name.at(operand_no));
+                    expected_code = search_and_replace(expected_code, "{op" + std::to_string(operand_no) + "}", op_name);
+                }
+
+                VALIDATE_TEST(writer.check_added_code(expected_code), all_tests_passed, test_idx++);
+            }
+        }
+
+        return all_tests_passed;
+    }
+
+    /** Replace every occurrence of @p search in @p src with @p replace and return the result.
+     *
+     * Replaced text is never rescanned, so the loop terminates even if @p replace contains @p search.
+     * An empty @p search would match at every position forever, so @p src is returned unchanged.
+     */
+    std::string search_and_replace(const std::string &src, const std::string &search, const std::string &replace)
+    {
+        std::string result = src;
+        if(search.empty())
+        {
+            return result;
+        }
+        size_t idx = 0;
+        while((idx = result.find(search, idx)) != std::string::npos)
+        {
+            result.replace(idx, search.size(), replace);
+            idx += replace.size(); // Resume the search after the inserted text.
+        }
+        return result;
+    }
+
+    /** Return the name of this test. */
+    std::string name() override
+    {
+        return "CLKernelWriterSubTileTest";
+    }
+
+private:
+    /** How an operand is derived from the tile declared for it. */
+    enum class AccessType
+    {
+        Full,
+        Vector,
+        Scalar,
+        ScalarOfVector,
+    };
+
+    /** Description of one operand: the tile to declare, the view to take and the expected name. */
+    struct TestOperand
+    {
+        int32_t height; // Height of the declared tile.
+        int32_t width;  // Width of the declared tile.
+        AccessType access_type;
+        int32_t start_row; // Row of the view (vector and scalar access).
+        int32_t start_col; // Column of the view (scalar access).
+        std::string name;  // Expected operand name, with {tile_name} as placeholder.
+    };
+
+    /** One operand set shared by all the writing actions. */
+    struct TestInfo
+    {
+        std::vector<TestOperand> operands;
+    };
+
+    /** One writing action together with the code it is expected to produce. */
+    struct TestAction
+    {
+        std::function<void(CLKernelWriter &, const std::vector<TileOperand> &)> write;
+
+        bool scalar_only;          // Whether the action only works with scalar values.
+        std::string expected_code; // Expected code, with {opN} placeholders for the operands.
+    };
+
+    std::vector<TestInfo> _tests{};
+    std::vector<TestAction> _actions{};
+};
+
+} // namespace ckw
+
+#endif // CKW_VALIDATION_SRC_TESTS_CLKERNELWRITERSUBTILETEST_H