diff options
author | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-07-24 15:47:34 +0100 |
---|---|---|
committer | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-07-25 10:08:05 +0000 |
commit | 0250fa6c2a0bdbf88c1264f32ad0a1a4e3fec3f3 (patch) | |
tree | f5772b7df56c30f8ce8bd7fc182ba803c44b1691 /compute_kernel_writer/src/cl/CLTile.cpp | |
parent | 25d26f4d86042e0ca52ac1bef4039b187f77b5b3 (diff) | |
download | ComputeLibrary-0250fa6c2a0bdbf88c1264f32ad0a1a4e3fec3f3.tar.gz |
Use CLTile for both variable and constant tiles
* It's easier to reuse CLTile for other things for example
tensor component if it can represent both variable
and constant tiles.
Partially resolves: COMPMID-6391
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: Ief06f670332cb339bd31b94a31b4bec186e1f1b8
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9966
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer/src/cl/CLTile.cpp')
-rw-r--r-- | compute_kernel_writer/src/cl/CLTile.cpp | 158 |
1 files changed, 126 insertions, 32 deletions
diff --git a/compute_kernel_writer/src/cl/CLTile.cpp b/compute_kernel_writer/src/cl/CLTile.cpp index cb0b22a23b..c6cf47d831 100644 --- a/compute_kernel_writer/src/cl/CLTile.cpp +++ b/compute_kernel_writer/src/cl/CLTile.cpp @@ -34,6 +34,7 @@ namespace ckw { CLTile::CLTile(const std::string &name, const TileInfo &info) + : _is_constant(false) { validate_tile_info(info); @@ -41,25 +42,66 @@ CLTile::CLTile(const std::string &name, const TileInfo &info) _info = info; } +CLTile::CLTile(const TileContainer &vals, DataType dt) + : _is_constant(true) +{ + const int32_t w = vals[0].size(); + const int32_t h = vals.size(); + + _info.width(w); + _info.height(h); + _info.data_type(dt); + + validate_tile_info(_info); + + _vals = TileContainer(h, std::vector<std::string>(w)); + + for(int32_t y = 0; y < h; ++y) + { + for(int32_t x = 0; x < w; ++x) + { + _vals[y][x] = vals[y][x]; + } + } +} + +const std::string &CLTile::name() const +{ + return _basename; +} + +const TileInfo &CLTile::info() const +{ + return _info; +} + TileVariable CLTile::scalar(int32_t row, int32_t col) const { // Clamp to nearest valid edge col = clamp(col, static_cast<int32_t>(0), _info.width() - 1); row = clamp(row, static_cast<int32_t>(0), _info.height() - 1); - TileVariable t; - t.str = create_var_name(row); - t.desc.dt = _info.data_type(); - t.desc.len = 1; - - // This check is required because if the width has only one element, we cannot use .s0 - if(_info.width() != 1) + if(_is_constant) { - // Automatic broadcasting - t.str += ".s" + dec_to_hex_as_string(col); + // We can use the vector method to retrieve the scalar variable stored in the constant tile + return vector(row, col, 1); } + else + { + TileVariable t; + t.str = create_var_name(row); + t.desc.dt = _info.data_type(); + t.desc.len = 1; - return t; + // This check is required because if the width has only one element, we cannot use .s0 + if(_info.width() != 1) + { + // Automatic broadcasting + t.str += ".s" + dec_to_hex_as_string(col); + } + + return t; + } } TileVariable CLTile::vector(int32_t row) const @@ -67,11 +109,18 @@ TileVariable CLTile::vector(int32_t row) const // Clamp to nearest valid edge row = clamp(row, static_cast<int32_t>(0), _info.height() - 1); - TileVariable t; - t.str = create_var_name(row); - t.desc.dt = _info.data_type(); - t.desc.len = _info.width(); - return t; + if(_is_constant) + { + return vector(row, 0, _info.width()); + } + else + { + TileVariable t; + t.str = create_var_name(row); + t.desc.dt = _info.data_type(); + t.desc.len = _info.width(); + return t; + } } TileVariable CLTile::vector(int32_t row, int32_t col_start, int32_t width) const @@ -83,38 +132,75 @@ TileVariable CLTile::vector(int32_t row, int32_t col_start, int32_t width) const row = clamp(row, static_cast<int32_t>(0), _info.height() - 1); TileVariable t; - t.str = create_var_name(row); t.desc.dt = _info.data_type(); t.desc.len = width; - if(_info.width() != 1) + if(_is_constant) + { + // The vector has the following form: ((data_typeN)(val0, val1,..., ValN-1)) + t.str = "((" + cl_get_variable_datatype_as_string(t.desc.dt, width) + ")"; + t.str += "("; + + int32_t col = col_start; + for(; col < width - 1; ++col) + { + t.str += _vals[row][col]; + t.str += ", "; + } + t.str += _vals[row][col]; + t.str += "))"; + } + else { - t.str += ".s"; - for(int i = 0; i < width; ++i) + t.str = create_var_name(row); + + if(_info.width() != 1) { - t.str += dec_to_hex_as_string(col_start + i); + t.str += ".s"; + for(int i = 0; i < width; ++i) + { + t.str += dec_to_hex_as_string(col_start + i); + } } } + return t; } std::vector<TileVariable> CLTile::all() const { std::vector<TileVariable> vars; - for(int32_t y = 0; y < _info.height(); ++y) + + if(_is_constant) { - TileVariable t; - t.str = create_var_name(y); - t.desc.dt = _info.data_type(); - t.desc.len = _info.width(); - vars.push_back(t); + for(int32_t y = 0; y < _info.height(); ++y) + { + for(int32_t x = 0; x < _info.width(); ++x) + { + // We can use the vector method to retrieve all the scalar variables stored in the constant tile + TileVariable t = vector(y, x, 1); + vars.push_back(t); + } + } } + else + { + for(int32_t y = 0; y < _info.height(); ++y) + { + TileVariable t; + t.str = create_var_name(y); + t.desc.dt = _info.data_type(); + t.desc.len = _info.width(); + vars.push_back(t); + } + } + return vars; } bool CLTile::is_assignable() const { - return true; + return !_is_constant; } std::string CLTile::create_var_name(int32_t row) const @@ -122,11 +208,7 @@ std::string CLTile::create_var_name(int32_t row) const std::string var_name = _basename; // If a scalar variable, we do not append the row index - if(_info.height() == 1) - { - return var_name; - } - else + if(_info.height() > 1) { var_name += "_"; var_name += std::to_string(row); @@ -134,4 +216,16 @@ std::string CLTile::create_var_name(int32_t row) const return var_name; } + +std::vector<int32_t> CLTile::supported_vector_lengths() const +{ + return std::vector<int32_t>{ 1, 2, 3, 4, 8, 16 }; +} + +void CLTile::validate_tile_info(const TileInfo &info) const +{ + CKW_ASSERT_MSG(cl_validate_vector_length(info.width()), "Unsupported TileInfo width"); + CKW_ASSERT_MSG(info.data_type() != DataType::Unknown, "DataType::Unknown is not supported"); +} + } // namespace ckw
\ No newline at end of file |