aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/cl/helpers/CLMemoryOpBufferHelper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/cl/helpers/CLMemoryOpBufferHelper.cpp')
-rw-r--r--compute_kernel_writer/src/cl/helpers/CLMemoryOpBufferHelper.cpp43
1 files changed, 25 insertions, 18 deletions
diff --git a/compute_kernel_writer/src/cl/helpers/CLMemoryOpBufferHelper.cpp b/compute_kernel_writer/src/cl/helpers/CLMemoryOpBufferHelper.cpp
index a98ebed8fa..7d16f35fbe 100644
--- a/compute_kernel_writer/src/cl/helpers/CLMemoryOpBufferHelper.cpp
+++ b/compute_kernel_writer/src/cl/helpers/CLMemoryOpBufferHelper.cpp
@@ -34,15 +34,16 @@
#include "src/cl/CLTile.h"
#include "src/ITensor.h"
#include "src/Tensor3dMapper.h"
+#include "src/TileView.h"
namespace ckw
{
-bool CLMemoryOpBufferHelper::validate(const CLKernelWriter *writer,
- const ITensor *tensor,
- const TensorSampler *sampler,
- const Tensor3dMapper *mapper,
- MemoryOperation op,
- const CLTile *dst)
+bool CLMemoryOpBufferHelper::validate(const CLKernelWriter *writer,
+ const ITensor *tensor,
+ const TensorSampler *sampler,
+ const Tensor3dMapper *mapper,
+ MemoryOperation op,
+ const TileView<CLTile> &dst)
{
CKW_UNUSED(writer, tensor, mapper, op, dst);
@@ -100,17 +101,14 @@ bool CLMemoryOpBufferHelper::validate(const CLKernelWriter *writer,
* The outermost block is x, then z and then y. This is why, if/else's covering for y are initialized
* at each row write. In some addressing modes, such as None, no if/else conditions are written.
*/
-void CLMemoryOpBufferHelper::initialize(const CLTile *dst, const CLTile *x, const CLTile *z, const CLTile *b)
+void CLMemoryOpBufferHelper::initialize(const CLTile *x, const CLTile *z, const CLTile *b)
{
- _dst = dst;
-
CKW_ASSERT(validate(_writer, _tensor, _sampler, _mapper.get(), _op, _dst));
- _ls_width_full = dst->info().width();
- _coord_x = x->scalar(0, 0).str;
- _coord_z = z->scalar(0, 0).str;
- _coord_b = b->scalar(0, 0).str;
- _coord_orig_z = _coord_z;
+ _coord_x = x->scalar(0, 0).str;
+ _coord_z = z->scalar(0, 0).str;
+ _coord_b = b->scalar(0, 0).str;
+ _coord_orig_z = _coord_z;
out_of_bound_initialize_x(_coord_x);
out_of_bound_initialize_z(_coord_z);
@@ -121,7 +119,7 @@ void CLMemoryOpBufferHelper::write_row(int32_t row_id, const std::string &coord_
// The only check required is on Y.
out_of_bound_initialize_y(coord_y);
- const std::string dst = _dst->vector(row_id).str;
+ const std::string dst = _dst.vector(row_id).str;
const std::string address = to_buffer_address(_coord_x, coord_y, _coord_z, _coord_b);
const std::string ls_buf = to_statement(_op, _ls_width_full, dst, address);
@@ -133,10 +131,17 @@ void CLMemoryOpBufferHelper::write_row(int32_t row_id, const std::string &coord_
// The left over load/store will be written in the finalize stage
if (_ls_width_part.size() != 0)
{
- int32_t col_start = 0;
+ int32_t col_start = 0;
+ const TileArea original_area = _dst.area();
+
for (int32_t partial_width : _ls_width_part)
{
- const std::string dst = _dst->vector(row_id, col_start, partial_width).str;
+ // Set the active area
+ const TileArea area(original_area.row_start(), original_area.row_end(), col_start,
+ col_start + partial_width);
+ _dst.area(area);
+
+ const std::string dst = _dst.vector(row_id).str;
const std::string coord_x = _coord_x + " + " + std::to_string(col_start);
const std::string address = to_buffer_address(coord_x, coord_y, _coord_z, _coord_b);
const std::string statement = to_statement(_op, partial_width, dst, address);
@@ -144,6 +149,8 @@ void CLMemoryOpBufferHelper::write_row(int32_t row_id, const std::string &coord_
col_start += partial_width;
}
+ // Restore the original area
+ _dst.area(original_area);
}
}
@@ -304,7 +311,7 @@ std::string CLMemoryOpBufferHelper::to_buffer_address(const std::string &x,
CKW_ASSERT(tensor_storage == TensorStorageType::BufferUint8Ptr);
const std::string ptr_buf = _tensor->storage(tensor_storage).val;
- const std::string dst_type = cl_data_type_rounded_up_to_valid_vector_width(_dst->info().data_type(), 1);
+ const std::string dst_type = cl_data_type_rounded_up_to_valid_vector_width(_dst.data_type(), 1);
std::string address;
address += "(__global ";