diff options
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.h')
-rw-r--r-- | compute_kernel_writer/src/cl/CLKernelWriter.h | 55 |
1 files changed, 54 insertions, 1 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h index 5df148da7b..a40698d7bb 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.h +++ b/compute_kernel_writer/src/cl/CLKernelWriter.h @@ -36,6 +36,11 @@ namespace ckw class CLTile; class CLTensorArgument; +class TensorSampler; +class TileOperand; +class TensorOperand; + +enum class MemoryOperation; /** OpenCL kernel writer. */ class CLKernelWriter : public KernelWriter @@ -76,9 +81,43 @@ public: /** Declare a tile given name and tile information * * Similar to @ref KernelWriter::declare_tile() - */ + */ TileOperand declare_tile(const std::string &name, const TileInfo &tile_info) override; + // ============================================================================================= + // Memory Operations + // ============================================================================================= + + /** Load the data from the tensor memory to the tile using the sampling information. + * + * Similar to @ref KernelWriter::op_load() + */ + void op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override; + + /** Load the data from the tensor memory to the tile in a dilated way using the sampling information. + * + * Similar to @ref KernelWriter::op_load_dilated() + */ + void op_load_dilated(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, + const TileOperand &dilation_x, const TileOperand &dilation_y) override; + + /** Store the data to the tensor memory from the tile using the sampling information. + * + * Similar to @ref KernelWriter::op_store() + */ + void op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override; + + /** Store the data to the tensor memory from the tile in a dilated way using the sampling information. + * + * Similar to @ref KernelWriter::op_store_dilated() + */ + void op_store_dilated(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, + const TileOperand &dilation_x, const TileOperand &dilation_y) override; + protected: /** Append the specified code to the kernel body source code. */ template <typename T, typename... TArgs> @@ -98,6 +137,20 @@ protected: /** Get the current kernel body source code. */ const std::string &body_source_code() const; +// For helper functions +private: + /** Return @ref CLTile object from the @ref TileOperand object. + * + * This function performs appropriate check before doing type casting. + */ + const CLTile &to_cl_tile(const TileOperand &operand); + + /** Helper function to consolidate all load/store logic in this class */ + void op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, + const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, + const CLTile &dilation_x, const CLTile &dilation_y); + +// For attributes private: /** This string contains the kernel body source code, not the full CL source code. * The full source code will only be generated when the user calls @ref KernelWriter::emit_kernel. |