diff options
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.h')
-rw-r--r-- | compute_kernel_writer/src/cl/CLKernelWriter.h | 104 |
1 files changed, 74 insertions, 30 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h index d7cf24d5e6..6485bae512 100644 --- a/compute_kernel_writer/src/cl/CLKernelWriter.h +++ b/compute_kernel_writer/src/cl/CLKernelWriter.h @@ -26,6 +26,7 @@ #define CKW_SRC_CL_CLKERNELWRITER_H #include "ckw/KernelWriter.h" + #include "src/TileView.h" #include <memory> @@ -73,7 +74,11 @@ public: void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) override; - void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) override; + void op_ternary(const TileOperand &dst, + TernaryOp op, + const TileOperand &first, + const TileOperand &second, + const TileOperand &third) override; // ============================================================================================= // Flow control @@ -81,14 +86,18 @@ public: void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) override; - void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) override; + void + op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) override; void op_else(const std::function<void()> &body) override; - void op_for_loop( - const TileOperand &var, BinaryOp cond_op, const TileOperand &cond_value, - const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value, - const std::function<void()> &body) override; + void op_for_loop(const TileOperand &var, + BinaryOp cond_op, + const TileOperand &cond_value, + const TileOperand &update_var, + AssignmentOp update_op, + const TileOperand &update_value, + const std::function<void()> &body) override; void op_return() override; @@ -132,26 +141,49 @@ public: // Memory Operations // ============================================================================================= - void op_load( - const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override; - - void op_load_dilated( - const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const TileOperand &dilation_x, const TileOperand &dilation_y) override; - - void op_store( - const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override; - - void op_store_dilated( - const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const TileOperand &dilation_x, const TileOperand &dilation_y) override; - - void op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override; + void op_load(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch) override; + + void op_load_dilated(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch, + const TileOperand &dilation_x, + const TileOperand &dilation_y) override; + + void op_store(const TensorOperand &tensor_op, + const TileOperand &tile_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch) override; + + void op_store_dilated(const TensorOperand &tensor_op, + const TileOperand &tile_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch, + const TileOperand &dilation_x, + const TileOperand &dilation_y) override; + + void op_load_indirect(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch) override; protected: /** Return a tile view containing a reference to @ref CLTile object and the active area. @@ -181,9 +213,17 @@ protected: // For helper functions private: /** Helper method to consolidate all load/store logic in this class */ - void op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer); + void op_load_store(MemoryOperation op, + const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch, + const TileView<CLTile> &dilation_x, + const TileView<CLTile> &dilation_y, + bool indirect_buffer); /** This function is the generic function to write both `if` and `else if` blocks. * @@ -195,7 +235,11 @@ private: * @param[in] body The function that writes the body of the else-if block. * @param[in] is_else_if True if this is an `else if` block, otherwise this is an `if` block. */ - void op_if_generic(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if); + void op_if_generic(const TileOperand &lhs, + BinaryOp op, + const TileOperand &rhs, + const std::function<void()> &body, + bool is_else_if); // For attributes private: |