diff options
Diffstat (limited to 'compute_kernel_writer/src/cl/helpers/ICLMemoryOpHelper.h')
-rw-r--r-- | compute_kernel_writer/src/cl/helpers/ICLMemoryOpHelper.h | 41 |
1 files changed, 25 insertions, 16 deletions
diff --git a/compute_kernel_writer/src/cl/helpers/ICLMemoryOpHelper.h b/compute_kernel_writer/src/cl/helpers/ICLMemoryOpHelper.h index 008d147fa2..7f363431e8 100644 --- a/compute_kernel_writer/src/cl/helpers/ICLMemoryOpHelper.h +++ b/compute_kernel_writer/src/cl/helpers/ICLMemoryOpHelper.h @@ -25,7 +25,11 @@ #ifndef CKW_SRC_CL_HELPERS_ICLMEMORYOPHELPER_H #define CKW_SRC_CL_HELPERS_ICLMEMORYOPHELPER_H +#include "ckw/TensorSampler.h" +#include "src/Tensor3dMapper.h" + #include <cstdint> +#include <memory> #include <string> namespace ckw @@ -34,7 +38,8 @@ namespace ckw // Forward Declarations class CLTile; class CLKernelWriter; -class Tensor3dMapper; +class ITensor; +class TensorSampler; enum class MemoryOperation; /** Base class OpenCL memory operation helper classes @@ -45,13 +50,15 @@ class ICLMemoryOpHelper public: /** Constructor * - * @param[in] x @ref ckw::CLKernelWriter object to write the code - * @param[in] mapper @ref ckw::Tensor3dMapper object that tells how to map the Nd tensor to 3d - * @param[in] op The memory operation to be done (e.g. Load/Store) + * @param[in] writer @ref ckw::CLKernelWriter object to write the code + * @param[in] tensor @ref ckw::ITensor object to perform the memory operation on + * @param[in] sampler @ref ckw::TensorSampler object that tells how to sample a tensor + * @param[in] op The memory operation to be done (e.g. Load/Store) */ - ICLMemoryOpHelper(CLKernelWriter *x, Tensor3dMapper *mapper, MemoryOperation op) - : _writer(x), _mapper(mapper), _op(op) + ICLMemoryOpHelper(CLKernelWriter *writer, ITensor *tensor, TensorSampler *sampler, MemoryOperation op) + : _writer(writer), _tensor(tensor), _sampler(sampler), _op(op) { + _mapper = std::make_unique<Tensor3dMapper>(tensor, sampler->format()); } /** Copy constructor */ @@ -72,7 +79,7 @@ public: * @param[in] z tile object that describes the z-coordinate of the tensor involved * @param[in] b tile object that describes the batch offset of the tensor involved */ - virtual void initialize(CLTile *dst, CLTile *x, CLTile *z, CLTile *b) = 0; + virtual void initialize(const CLTile *dst, const CLTile *x, const CLTile *z, const CLTile *b) = 0; /** Method that writes the actual code to the writer that performs the mentioned memory * operation on the tile initialized. It writes the code for a specific row given in the @@ -81,7 +88,7 @@ public: * @param[in] row_id row id * @param[in] coord_y y-coordinate as string */ - virtual void write(int32_t row_id, const std::string &coord_y) = 0; + virtual void write_row(int32_t row_id, const std::string &coord_y) = 0; /** Method that finalizes the code in the writer object. This part is usually for taking * care of finalizing anything that's been initialized inside @ref IMemoryHelper::initialize() @@ -91,14 +98,16 @@ public: virtual void finalize() = 0; protected: - CLKernelWriter *_writer; - Tensor3dMapper *_mapper; - MemoryOperation _op; - CLTile *_dst{ nullptr }; - int32_t _ls_width_full{ 0 }; - std::string _coord_x{}; - std::string _coord_z{}; - std::string _coord_b{}; + CLKernelWriter *_writer{ nullptr }; + ITensor *_tensor{ nullptr }; + TensorSampler *_sampler{ nullptr }; + MemoryOperation _op; + std::unique_ptr<Tensor3dMapper> _mapper{ nullptr }; + const CLTile *_dst{ nullptr }; + int32_t _ls_width_full{ 0 }; + std::string _coord_x{}; + std::string _coord_z{}; + std::string _coord_b{}; }; } // namespace ckw |