diff options
Diffstat (limited to 'compute_kernel_writer/src')
-rw-r--r-- | compute_kernel_writer/src/KernelWriter.cpp | 7 | ||||
-rw-r--r-- | compute_kernel_writer/src/TileOperand.cpp | 38 | ||||
-rw-r--r-- | compute_kernel_writer/src/TileView.cpp | 57 | ||||
-rw-r--r-- | compute_kernel_writer/src/TileView.h | 189 | ||||
-rw-r--r-- | compute_kernel_writer/src/cl/CLKernelWriter.cpp | 229 | ||||
-rw-r--r-- | compute_kernel_writer/src/cl/CLKernelWriter.h | 11 | ||||
-rw-r--r-- | compute_kernel_writer/src/cl/CLTile.cpp | 5 |
7 files changed, 414 insertions, 122 deletions
diff --git a/compute_kernel_writer/src/KernelWriter.cpp b/compute_kernel_writer/src/KernelWriter.cpp index fb0e62c8ce..0bea1200b7 100644 --- a/compute_kernel_writer/src/KernelWriter.cpp +++ b/compute_kernel_writer/src/KernelWriter.cpp @@ -27,10 +27,13 @@ #include "ckw/TileOperand.h" #include "ckw/types/TargetArchitecture.h" #include "ckw/types/TargetLanguage.h" +#include "src/TileView.h" #include "src/cl/CLKernelWriter.h" #include "src/cl/CLTensorArgument.h" #include "src/cl/CLTile.h" +#include <tuple> + namespace ckw { @@ -90,9 +93,9 @@ TileOperand KernelWriter::create_tile_operand(ITile &tile) return TileOperand(tile); } -ITile &KernelWriter::get_tile(const TileOperand &operand) +std::tuple<ITile &, TileArea> KernelWriter::get_tile(const TileOperand &operand) { - return operand._tile; + return { *operand._tile, { operand._row_start, operand._row_end, operand._col_start, operand._col_end } }; } TensorOperand KernelWriter::create_tensor_operand(ITensor &tensor) diff --git a/compute_kernel_writer/src/TileOperand.cpp b/compute_kernel_writer/src/TileOperand.cpp index 7d180feec8..3dfa2b8b2b 100644 --- a/compute_kernel_writer/src/TileOperand.cpp +++ b/compute_kernel_writer/src/TileOperand.cpp @@ -23,13 +23,49 @@ */ #include "ckw/TileOperand.h" +#include "ckw/Error.h" +#include "src/ITile.h" namespace ckw { TileOperand::TileOperand(ITile &tile) - : _tile(tile) + : _tile(&tile), _row_start(0), _row_end(tile.info().height()), _col_start(0), _col_end(tile.info().width()) { } +TileOperand::TileOperand(const TileOperand &operand, int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end) + : _tile(operand._tile), _row_start(row_start), _row_end(row_end), _col_start(col_start), _col_end(col_end) +{ + CKW_ASSERT(row_start >= 0 && row_start < _tile->info().height()); + CKW_ASSERT(row_end > row_start && row_end <= _tile->info().height()); + CKW_ASSERT(col_start >= 0 && col_start < _tile->info().width()); + CKW_ASSERT(col_end > col_start && col_end <= _tile->info().width()); +} + +TileOperand TileOperand::tile(int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end) const +{ + CKW_ASSERT(row_start >= 0 && _row_start + row_start < _row_end); + CKW_ASSERT(row_end > row_start && _row_start + row_end <= _row_end); + CKW_ASSERT(col_start >= 0 && _col_start + col_start < _col_end); + CKW_ASSERT(col_end > col_start && _col_start + col_end <= _col_end); + + return TileOperand(*this, _row_start + row_start, _row_start + row_end, _col_start + col_start, _col_start + col_end); +} + +TileOperand TileOperand::row(int32_t row) const +{ + CKW_ASSERT(row >= 0 && _row_start + row < _row_end); + + return tile(_row_start + row, _row_start + row + 1, _col_start, _col_end); +} + +TileOperand TileOperand::scalar(int32_t row, int32_t col) const +{ + CKW_ASSERT(row >= 0 && _row_start + row < _row_end); + CKW_ASSERT(col >= 0 && _col_start + col < _col_end); + + return tile(_row_start + row, _row_start + row + 1, _col_start + col, _col_start + col + 1); +} + } // namespace ckw diff --git a/compute_kernel_writer/src/TileView.cpp b/compute_kernel_writer/src/TileView.cpp new file mode 100644 index 0000000000..ea803f92f4 --- /dev/null +++ b/compute_kernel_writer/src/TileView.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/TileView.h" + +#include <cstdint> + +namespace ckw +{ + +TileArea::TileArea(int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end) + : _row_start(row_start), _row_end(row_end), _col_start(col_start), _col_end(col_end) +{ +} + +int32_t TileArea::row_start() const +{ + return _row_start; +} + +int32_t TileArea::row_end() const +{ + return _row_end; +} + +int32_t TileArea::col_start() const +{ + return _col_start; +} + +int32_t TileArea::col_end() const +{ + return _col_end; +} + +} // namespace ckw diff --git a/compute_kernel_writer/src/TileView.h b/compute_kernel_writer/src/TileView.h new file mode 100644 index 0000000000..e0d034fa8d --- /dev/null +++ b/compute_kernel_writer/src/TileView.h @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef CKW_SRC_TILEVIEW_H +#define CKW_SRC_TILEVIEW_H + +#include "ckw/Error.h" +#include "ckw/types/DataType.h" +#include "src/ITile.h" + +#include <cstdint> + +namespace ckw +{ + +/** A rectangular active area of a tile. */ +class TileArea +{ +public: + /** Create a new tile rectangular active area. + * + * The range of rows and columns is defined by pairs of start and end indices, inclusive lower and exclusive upper. + * In other word, any row and column indices satisfied the following conditions will be part of the active area: + * + * row_start <= row_index < row_end + * col_start <= col_index < col_end + * + * @param[in] row_start The start index of the row range. + * @param[in] row_end The end index of the row range. + * @param[in] col_start The start index of the column range. + * @param[in] col_end The end index of the column range. + */ + TileArea(int32_t row_start, int32_t row_end, int32_t col_start, int32_t col_end); + + /** Get the start row index. */ + int32_t row_start() const; + + /** Get the end row (exclusive) index. */ + int32_t row_end() const; + + /** Get the start column index. */ + int32_t col_start() const; + + /** Get the end column (exclusive) index. */ + int32_t col_end() const; + +private: + int32_t _row_start; + int32_t _row_end; + int32_t _col_start; + int32_t _col_end; +}; + +/** A rectangular view of a tile. */ +template <typename T> +class TileView +{ +public: + /** Create a tile view that refers to the whole tile. + * + * @param[in] tile The tile object. + */ + TileView(const T &tile) + : _tile(&tile), _area(0, tile.info().height(), 0, tile.info().width()) + { + } + + /** Create a new rectangular view of the given tile. + * + * @param[in] tile The tile object. + * @param[in] area The rectangular active area. + */ + TileView(const T &tile, const TileArea &area) + : _tile(&tile), _area(area) + { + } + + /** Get the tile object. + * + * The caller must guarantee that the tile view refers to the whole tile. + */ + const T &full_tile() const + { + CKW_ASSERT(is_full_tile()); + + return *_tile; + } + + /** Get the data type of the tile. */ + DataType data_type() const + { + return _tile->info().data_type(); + } + + /** Get the start row index. */ + int32_t row_start() const + { + return _area.row_start(); + } + + /** Get the end row index. */ + int32_t row_end() const + { + return _area.row_end(); + } + + /** Get the start column index. */ + int32_t col_start() const + { + return _area.col_start(); + } + + /** Get the end column index. */ + int32_t col_end() const + { + return _area.col_end(); + } + + /** Get the height of the tile view. */ + int32_t height() const + { + return _area.row_end() - _area.row_start(); + } + + /** Get the width of the tile view. */ + int32_t width() const + { + return _area.col_end() - _area.col_start(); + } + + /** See @ref IVectorAccess::vector. */ + TileVariable vector(int32_t row) const + { + return _tile->vector(row_start() + row, col_start(), width()); + } + + /** See @ref IScalarAccess::scalar. */ + TileVariable scalar(int32_t row, int32_t col) const + { + return _tile->scalar(row_start() + row, col_start() + col); + } + + /** Get the name of the tile. */ + const std::string &name() const + { + return _tile->name(); + } + + /** Get whether the tile view is a scalar element. */ + bool is_scalar() const + { + return height() == 1 && width() == 1; + } + + /** Get whether the tile view refers to the whole tile. */ + bool is_full_tile() const + { + return row_start() == 0 && row_end() == _tile->info().height() && col_start() == 0 && col_end() == _tile->info().width(); + } + +private: + const T *_tile; + TileArea _area; +}; + +} // namespace ckw + +#endif // CKW_SRC_TILEVIEW_H diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp index 4074da7912..2db9c139b7 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp +++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp @@ -32,6 +32,7 @@ #include "ckw/types/MemoryOperation.h" #include "ckw/types/TargetLanguage.h" #include "src/ITensorComponent.h" +#include "src/TileView.h" #include "src/cl/CLHelpers.h" #include "src/cl/CLTensorArgument.h" #include "src/cl/CLTile.h" @@ -42,6 +43,7 @@ #include <algorithm> #include <cstdint> +#include <tuple> #include <vector> namespace ckw @@ -112,39 +114,39 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name) void CLKernelWriter::op_assign(const TileOperand &dst, const TileOperand &src) { - const auto &dst_tile = to_cl_tile(dst); - const auto &src_tile = to_cl_tile(src); + const auto dst_view = to_cl_tile_view(dst); + const auto src_view = to_cl_tile_view(src); - const auto dst_w = dst_tile.info().width(); - const auto dst_h = dst_tile.info().height(); - const auto src_w = src_tile.info().width(); + const auto dst_w = dst_view.width(); + const auto dst_h = dst_view.height(); + const auto src_w = src_view.width(); - const auto data_type_str = cl_get_variable_datatype_as_string(dst_tile.info().data_type(), dst_w); + const auto data_type_str = cl_get_variable_datatype_as_string(dst_view.data_type(), dst_w); const auto broadcast_src_x = dst_w != 1 && src_w == 1; const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : ""; - CKW_ASSERT_MSG(src_tile.info().data_type() == dst_tile.info().data_type(), "Source and destination type must match."); - CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match."); + CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension."); // Broadcasting on y dimension is automatic (see CLTile::vector). for(int32_t y = 0; y < dst_h; ++y) { - append_code(dst_tile.vector(y).str, " = ", src_prefix, src_tile.vector(y).str, ";\n"); + append_code(dst_view.vector(y).str, " = ", src_prefix, src_view.vector(y).str, ";\n"); } } void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) { - const auto &dst_tile = to_cl_tile(dst); - const auto &src_tile = to_cl_tile(src); + const auto dst_view = to_cl_tile_view(dst); + const auto src_view = to_cl_tile_view(src); - const auto dst_w = dst_tile.info().width(); - const auto dst_h = dst_tile.info().height(); - const auto src_w = src_tile.info().width(); + const auto dst_w = dst_view.width(); + const auto dst_h = dst_view.height(); + const auto src_w = src_view.width(); - const auto dst_type = dst_tile.info().data_type(); + const auto dst_type = dst_view.data_type(); const auto convert_type_str = cl_get_variable_datatype_as_string(dst_type, src_w); const auto dst_type_str = cl_get_variable_datatype_as_string(dst_type, dst_w); @@ -155,27 +157,27 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con const auto broadcast_x = dst_w != 1 && src_w == 1; const std::string prefix = broadcast_x ? "(" + dst_type_str + ")" : ""; - CKW_ASSERT_MSG(src_tile.info().data_type() != dst_tile.info().data_type(), "Source and destination type must be different."); - CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(src_view.data_type() != dst_view.data_type(), "Source and destination type must be different."); + CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension."); // Broadcasting on y dimension is automatic (see CLTile::vector). for(int32_t y = 0; y < dst_h; ++y) { - append_code(dst_tile.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", src_tile.vector(y).str, ");\n"); + append_code(dst_view.vector(y).str, " = ", prefix, "convert_", convert_type_str, sat, "(", src_view.vector(y).str, ");\n"); } } void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) { - const auto &dst_tile = to_cl_tile(dst); - const auto &src_tile = to_cl_tile(src); + const auto dst_view = to_cl_tile_view(dst); + const auto src_view = to_cl_tile_view(src); - const auto dst_w = dst_tile.info().width(); - const auto dst_h = dst_tile.info().height(); - const auto src_w = src_tile.info().width(); + const auto dst_w = dst_view.width(); + const auto dst_h = dst_view.height(); + const auto src_w = src_view.width(); - const auto data_type_str = cl_get_variable_datatype_as_string(dst_tile.info().data_type(), dst_w); + const auto data_type_str = cl_get_variable_datatype_as_string(dst_view.data_type(), dst_w); const auto broadcast_src_x = dst_w != 1 && src_w == 1; const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : ""; @@ -186,35 +188,34 @@ void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOper const auto op_prefix = op_is_func ? op_name + "(" : op_name; const auto op_suffix = op_is_func ? ")" : ""; - CKW_ASSERT_MSG(src_tile.info().data_type() == dst_tile.info().data_type(), "Source and destination type must match."); - CKW_ASSERT_MSG(!is_data_type_float(src_tile.info().data_type()), "Logical and bitwise not only work with integer."); - CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(src_view.data_type() == dst_view.data_type(), "Source and destination type must match."); + CKW_ASSERT_MSG(src_view.height() == dst_h || src_view.height() == 1, "Tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension."); // Broadcasting on y dimension is automatic (see CLTile::vector). for(int32_t y = 0; y < dst_h; ++y) { - append_code(dst_tile.vector(y).str, " = ", src_prefix, op_prefix, src_tile.vector(y).str, op_suffix, ";\n"); + append_code(dst_view.vector(y).str, " = ", src_prefix, op_prefix, src_view.vector(y).str, op_suffix, ";\n"); } } void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) { - const auto &dst_tile = to_cl_tile(dst); - const auto &lhs_tile = to_cl_tile(first); - const auto &rhs_tile = to_cl_tile(second); + const auto dst_view = to_cl_tile_view(dst); + const auto lhs_view = to_cl_tile_view(first); + const auto rhs_view = to_cl_tile_view(second); - const auto dst_w = dst_tile.info().width(); - const auto dst_h = dst_tile.info().height(); - const auto lhs_w = lhs_tile.info().width(); - const auto rhs_w = rhs_tile.info().width(); + const auto dst_w = dst_view.width(); + const auto dst_h = dst_view.height(); + const auto lhs_w = lhs_view.width(); + const auto rhs_w = rhs_view.width(); - const auto data_type = lhs_tile.info().data_type(); + const auto data_type = lhs_view.data_type(); - CKW_ASSERT_MSG(lhs_tile.info().data_type() == rhs_tile.info().data_type(), "LHS and RHS type must match."); + CKW_ASSERT_MSG(lhs_view.data_type() == rhs_view.data_type(), "LHS and RHS type must match."); - CKW_ASSERT_MSG(lhs_tile.info().height() == dst_h || lhs_tile.info().height() == 1, "LHS tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(rhs_tile.info().height() == dst_h || rhs_tile.info().height() == 1, "RHS tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(lhs_view.height() == dst_h || lhs_view.height() == 1, "LHS tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(rhs_view.height() == dst_h || rhs_view.height() == 1, "RHS tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, "LHS tile width must match destination or LHS is broadcasting in x dimension."); CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, "RHS tile width must match destination or RHS is broadcasting in x dimension."); @@ -230,10 +231,10 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp for(int32_t k = 0; k < lhs_w; ++k) { append_code( - dst_tile.scalar(x, y).str, " = fma(", - lhs_tile.scalar(k, y).str, ", ", - rhs_tile.scalar(k, x).str, ", ", - dst_tile.scalar(x, y).str, ");\n"); + dst_view.scalar(x, y).str, " = fma(", + lhs_view.scalar(k, y).str, ", ", + rhs_view.scalar(k, x).str, ", ", + dst_view.scalar(x, y).str, ");\n"); } } } @@ -259,25 +260,25 @@ void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOp // Broadcasting on y dimension is automatic (see CLTile::vector). for(int32_t y = 0; y < dst_h; ++y) { - append_code(dst_tile.vector(y).str, op_prefix, lhs_prefix, lhs_tile.vector(y).str, op_separator, rhs_prefix, rhs_tile.vector(y).str, op_suffix); + append_code(dst_view.vector(y).str, op_prefix, lhs_prefix, lhs_view.vector(y).str, op_separator, rhs_prefix, rhs_view.vector(y).str, op_suffix); } } } void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) { - const auto &dst_tile = to_cl_tile(dst); - const auto &first_tile = to_cl_tile(first); - const auto &second_tile = to_cl_tile(second); - const auto &third_tile = to_cl_tile(third); - - const auto dst_w = dst_tile.info().width(); - const auto dst_h = dst_tile.info().height(); - const auto first_w = first_tile.info().width(); - const auto second_w = second_tile.info().width(); - const auto third_w = third_tile.info().width(); - - const auto data_type = dst_tile.info().data_type(); + const auto dst_view = to_cl_tile_view(dst); + const auto first_view = to_cl_tile_view(first); + const auto second_view = to_cl_tile_view(second); + const auto third_view = to_cl_tile_view(third); + + const auto dst_w = dst_view.width(); + const auto dst_h = dst_view.height(); + const auto first_w = first_view.width(); + const auto second_w = second_view.width(); + const auto third_w = third_view.width(); + + const auto data_type = dst_view.data_type(); const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w); const auto op_info = cl_get_ternary_op(op); @@ -293,12 +294,12 @@ void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const Tile const std::string third_prefix = broadcast_third_x ? "(" + data_type_str + ")" : ""; CKW_ASSERT_MSG(op_is_func, "The only supported ternary operator is function."); - CKW_ASSERT_MSG(second_tile.info().data_type() == dst_tile.info().data_type(), "2nd source and destination type must match."); - CKW_ASSERT_MSG(third_tile.info().data_type() == dst_tile.info().data_type(), "3rd source and destination type must match."); + CKW_ASSERT_MSG(second_view.data_type() == dst_view.data_type(), "2nd source and destination type must match."); + CKW_ASSERT_MSG(third_view.data_type() == dst_view.data_type(), "3rd source and destination type must match."); - CKW_ASSERT_MSG(first_tile.info().height() == dst_h || first_tile.info().height() == 1, "1st tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(second_tile.info().height() == dst_h || second_tile.info().height() == 1, "2nd tile height must match or source is broadcasting in y dimension."); - CKW_ASSERT_MSG(third_tile.info().height() == dst_h || third_tile.info().height() == 1, "3rd tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(first_view.height() == dst_h || first_view.height() == 1, "1st tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(second_view.height() == dst_h || second_view.height() == 1, "2nd tile height must match or source is broadcasting in y dimension."); + CKW_ASSERT_MSG(third_view.height() == dst_h || third_view.height() == 1, "3rd tile height must match or source is broadcasting in y dimension."); CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, "1st tile width must match or source is broadcasting in x dimension."); CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, "2nd tile width must match or source is broadcasting in x dimension."); @@ -308,30 +309,30 @@ void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const Tile for(int32_t y = 0; y < dst_h; ++y) { append_code( - dst_tile.vector(y).str, " = ", op_name, "(", - first_prefix, first_tile.vector(y).str, ", ", - second_prefix, second_tile.vector(y).str, ", ", - third_prefix, third_tile.vector(y).str, ");\n"); + dst_view.vector(y).str, " = ", op_name, "(", + first_prefix, first_view.vector(y).str, ", ", + second_prefix, second_view.vector(y).str, ", ", + third_prefix, third_view.vector(y).str, ");\n"); } } void CLKernelWriter::op_if_generic(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if) { - const auto &lhs_tile = to_cl_tile(lhs); - const auto &rhs_tile = to_cl_tile(rhs); + const auto lhs_view = to_cl_tile_view(lhs); + const auto rhs_view = to_cl_tile_view(rhs); - const auto op_name = std::get<1>(cl_get_binary_op(op, lhs_tile.info().data_type())); + const auto op_name = std::get<1>(cl_get_binary_op(op, lhs_view.data_type())); CKW_ASSERT(op == BinaryOp::Less || op == BinaryOp::LessEqual || op == BinaryOp::Equal || op == BinaryOp::GreaterEqual || op == BinaryOp::Greater); - CKW_ASSERT(lhs_tile.is_scalar()); - CKW_ASSERT(rhs_tile.is_scalar()); + CKW_ASSERT(lhs_view.is_scalar()); + CKW_ASSERT(rhs_view.is_scalar()); if(is_else_if) { append_code("else "); } - append_code("if (", lhs_tile.scalar(0, 0).str, " ", op_name, " ", rhs_tile.scalar(0, 0).str, ")\n{\n"); + append_code("if (", lhs_view.scalar(0, 0).str, " ", op_name, " ", rhs_view.scalar(0, 0).str, ")\n{\n"); write_body(body); append_code("}\n"); } @@ -358,25 +359,25 @@ void CLKernelWriter::op_for_loop( const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value, const std::function<void()> &body) { - const auto &var_tile = to_cl_tile(var); - const auto &cond_value_tile = to_cl_tile(cond_value); - const auto &update_var_tile = to_cl_tile(update_var); - const auto &update_value_tile = to_cl_tile(update_value); + const auto var_view = to_cl_tile_view(var); + const auto cond_value_view = to_cl_tile_view(cond_value); + const auto update_var_view = to_cl_tile_view(update_var); + const auto update_value_view = to_cl_tile_view(update_value); - CKW_ASSERT(var_tile.is_scalar()); - CKW_ASSERT(cond_value_tile.is_scalar()); - CKW_ASSERT(update_var_tile.is_scalar()); - CKW_ASSERT(update_value_tile.is_scalar()); + CKW_ASSERT(var_view.is_scalar()); + CKW_ASSERT(cond_value_view.is_scalar()); + CKW_ASSERT(update_var_view.is_scalar()); + CKW_ASSERT(update_value_view.is_scalar()); - CKW_ASSERT(var_tile.info().data_type() == cond_value_tile.info().data_type()); - CKW_ASSERT(update_var_tile.info().data_type() == update_value_tile.info().data_type()); + CKW_ASSERT(var_view.data_type() == cond_value_view.data_type()); + CKW_ASSERT(update_var_view.data_type() == update_value_view.data_type()); - const auto cond_op_name = std::get<1>(cl_get_binary_op(cond_op, var_tile.info().data_type())); + const auto cond_op_name = std::get<1>(cl_get_binary_op(cond_op, var_view.data_type())); CKW_ASSERT(cond_op == BinaryOp::Less || cond_op == BinaryOp::LessEqual || cond_op == BinaryOp::Equal || cond_op == BinaryOp::GreaterEqual || cond_op == BinaryOp::Greater); append_code( - "for (; ", var_tile.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_tile.scalar(0, 0).str, "; ", - update_var_tile.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", update_value_tile.scalar(0, 0).str, ")\n{\n"); + "for (; ", var_view.scalar(0, 0).str, " ", cond_op_name, " ", cond_value_view.scalar(0, 0).str, "; ", + update_var_view.scalar(0, 0).str, " ", cl_get_assignment_op_as_string(update_op), " ", update_value_view.scalar(0, 0).str, ")\n{\n"); write_body(body); append_code("}\n"); } @@ -388,14 +389,14 @@ void CLKernelWriter::op_return() void CLKernelWriter::op_get_global_id(const TileOperand &dst, int32_t dim) { - const auto &tile = to_cl_tile(dst); + const auto tile_view = to_cl_tile_view(dst); - CKW_ASSERT(tile.is_scalar()); - CKW_ASSERT(tile.info().data_type() == DataType::Int32 || tile.info().data_type() == DataType::Uint32); + CKW_ASSERT(tile_view.is_scalar()); + CKW_ASSERT(tile_view.data_type() == DataType::Int32 || tile_view.data_type() == DataType::Uint32); CKW_ASSERT(dim >= 0 && dim <= 2); - append_code(tile.scalar(0, 0).str, " = get_global_id(", std::to_string(dim), ");\n"); + append_code(tile_view.scalar(0, 0).str, " = get_global_id(", std::to_string(dim), ");\n"); } void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileOperand> &operands) @@ -405,13 +406,12 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO for(auto &op : operands) { - const auto &tile = to_cl_tile(op); - const auto &info = tile.info(); + const auto tile_view = to_cl_tile_view(op); - const auto &name = tile.name(); - const auto width = info.width(); - const auto height = info.height(); - const auto data_type = info.data_type(); + const auto name = tile_view.name(); + const auto width = tile_view.width(); + const auto height = tile_view.height(); + const auto data_type = tile_view.data_type(); // Construct the format specifier to print out one row of the tile. std::string row_format("%"); @@ -479,7 +479,7 @@ void CLKernelWriter::op_print(const std::string &prefix, const std::vector<TileO // Construct the variable arguments for the printf statement. for(int32_t row = 0; row < height; ++row) { - args_code += ", " + tile.vector(row).str; + args_code += ", " + tile_view.vector(row).str; } } @@ -566,9 +566,11 @@ void CLKernelWriter::op_write_raw_code(const std::string &raw_code) append_code(raw_code); } -const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand) const +TileView<CLTile> CLKernelWriter::to_cl_tile_view(const TileOperand &operand) const { - const auto &tile = get_tile(operand); + const auto tile_and_area = get_tile(operand); + ITile &tile = std::get<0>(tile_and_area); + const TileArea area = std::get<1>(tile_and_area); #ifdef COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED // Check if the tile is a CLTile created by this kernel writer. @@ -620,7 +622,7 @@ const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand) const } #endif // COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED - return static_cast<const CLTile &>(tile); + return { static_cast<CLTile &>(tile), area }; } void CLKernelWriter::op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, @@ -636,10 +638,10 @@ void CLKernelWriter::op_load_dilated(const TileOperand &tile_op, const TensorOpe const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, const TileOperand &dilation_x, const TileOperand &dilation_y) { - const auto &dil_x_tile = to_cl_tile(dilation_x); - const auto &dil_y_tile = to_cl_tile(dilation_y); + const auto dil_x_view = to_cl_tile_view(dilation_x); + const auto dil_y_view = to_cl_tile_view(dilation_y); - op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_tile, dil_y_tile, false /* indirect buffer */); + op_load_store(MemoryOperation::Load, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */); } void CLKernelWriter::op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, @@ -655,14 +657,14 @@ void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op, const Tile const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, const TileOperand &dilation_x, const TileOperand &dilation_y) { - const auto &dil_x_tile = to_cl_tile(dilation_x); - const auto &dil_y_tile = to_cl_tile(dilation_y); + const auto dil_x_view = to_cl_tile_view(dilation_x); + const auto dil_y_view = to_cl_tile_view(dilation_y); - op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_tile, dil_y_tile, false /* indirect buffer */); + op_load_store(MemoryOperation::Store, tile_op, tensor_op, sampler, x, y, z, batch, dil_x_view, dil_y_view, false /* indirect buffer */); } void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) { const CLTile dilation_x({ { "1" } }, DataType::Int32); const CLTile dilation_y({ { "1" } }, DataType::Int32); @@ -671,8 +673,8 @@ void CLKernelWriter::op_load_indirect(const TileOperand &tile_op, const TensorOp } void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const CLTile &dilation_x, const CLTile &dilation_y, bool indirect_buffer) + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, + const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer) { CKW_UNUSED(dilation_x); CKW_ASSERT(dilation_x.is_scalar()); @@ -681,7 +683,7 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o if(indirect_buffer) { - CKW_ASSERT(dilation_y.scalar(0,0).str == "((int)(1))" && dilation_x.scalar(0,0).str == "((int)(1))"); + CKW_ASSERT(dilation_y.scalar(0, 0).str == "((int)(1))" && dilation_x.scalar(0, 0).str == "((int)(1))"); } ITensor &tensor = get_tensor(tensor_op); @@ -700,11 +702,12 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o CKW_THROW_MSG("Unsupported tensor storage"); } - const auto &tile = to_cl_tile(tile_op); - const auto &x_tile = to_cl_tile(x); - const auto &y_tile = to_cl_tile(y); - const auto &z_tile = to_cl_tile(z); - const auto &batch_tile = to_cl_tile(batch); + // Load/store op doesn't support sub-tile access. + const auto tile = to_cl_tile_view(tile_op).full_tile(); + const auto x_tile = to_cl_tile_view(x).full_tile(); + const auto y_tile = to_cl_tile_view(y).full_tile(); + const auto z_tile = to_cl_tile_view(z).full_tile(); + const auto batch_tile = to_cl_tile_view(batch).full_tile(); CKW_ASSERT(x_tile.is_scalar()); CKW_ASSERT(z_tile.is_scalar()); diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h index 1e2e5dc910..d7cf24d5e6 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.h +++ b/compute_kernel_writer/src/cl/CLKernelWriter.h @@ -26,6 +26,7 @@ #define CKW_SRC_CL_CLKERNELWRITER_H #include "ckw/KernelWriter.h" +#include "src/TileView.h" #include <memory> #include <set> @@ -150,14 +151,14 @@ public: const TileOperand &dilation_x, const TileOperand &dilation_y) override; void op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override; + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override; protected: - /** Return @ref CLTile object from the @ref TileOperand object. + /** Return a tile view containing a reference to @ref CLTile object and the active area. * * This function performs appropriate check before doing type casting. */ - const CLTile &to_cl_tile(const TileOperand &operand) const; + TileView<CLTile> to_cl_tile_view(const TileOperand &operand) const; /** Append the specified code to the kernel body source code. */ template <typename T, typename... TArgs> @@ -181,8 +182,8 @@ protected: private: /** Helper method to consolidate all load/store logic in this class */ void op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const CLTile &dilation_x, const CLTile &dilation_y, bool indirect_buffer); + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, + const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer); /** This function is the generic function to write both `if` and `else if` blocks. * diff --git a/compute_kernel_writer/src/cl/CLTile.cpp b/compute_kernel_writer/src/cl/CLTile.cpp index 556db0f47b..0cce69a9e1 100644 --- a/compute_kernel_writer/src/cl/CLTile.cpp +++ b/compute_kernel_writer/src/cl/CLTile.cpp @@ -125,6 +125,9 @@ TileVariable CLTile::vector(int32_t row) const TileVariable CLTile::vector(int32_t row, int32_t col_start, int32_t width) const { + CKW_ASSERT(col_start >= 0 && col_start < _info.width()); + CKW_ASSERT(col_start + width <= _info.width()); + // Validate the new vector length cl_validate_vector_length(width); @@ -154,7 +157,7 @@ TileVariable CLTile::vector(int32_t row, int32_t col_start, int32_t width) const { t.str = create_var_name(row); - if(_info.width() != 1) + if(_info.width() != 1 && _info.width() != width) { t.str += ".s"; for(int i = 0; i < width; ++i) |