aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/cl/CLKernelWriter.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/cl/CLKernelWriter.h')
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.h104
1 files changed, 74 insertions, 30 deletions
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h
index d7cf24d5e6..6485bae512 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.h
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.h
@@ -26,6 +26,7 @@
#define CKW_SRC_CL_CLKERNELWRITER_H
#include "ckw/KernelWriter.h"
+
#include "src/TileView.h"
#include <memory>
@@ -73,7 +74,11 @@ public:
void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) override;
- void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) override;
+ void op_ternary(const TileOperand &dst,
+ TernaryOp op,
+ const TileOperand &first,
+ const TileOperand &second,
+ const TileOperand &third) override;
// =============================================================================================
// Flow control
@@ -81,14 +86,18 @@ public:
void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) override;
- void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) override;
+ void
+ op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) override;
void op_else(const std::function<void()> &body) override;
- void op_for_loop(
- const TileOperand &var, BinaryOp cond_op, const TileOperand &cond_value,
- const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value,
- const std::function<void()> &body) override;
+ void op_for_loop(const TileOperand &var,
+ BinaryOp cond_op,
+ const TileOperand &cond_value,
+ const TileOperand &update_var,
+ AssignmentOp update_op,
+ const TileOperand &update_value,
+ const std::function<void()> &body) override;
void op_return() override;
@@ -132,26 +141,49 @@ public:
// Memory Operations
// =============================================================================================
- void op_load(
- const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
-
- void op_load_dilated(
- const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const TileOperand &dilation_x, const TileOperand &dilation_y) override;
-
- void op_store(
- const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
-
- void op_store_dilated(
- const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const TileOperand &dilation_x, const TileOperand &dilation_y) override;
-
- void op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
+ void op_load(const TileOperand &tile_op,
+ const TensorOperand &tensor_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch) override;
+
+ void op_load_dilated(const TileOperand &tile_op,
+ const TensorOperand &tensor_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch,
+ const TileOperand &dilation_x,
+ const TileOperand &dilation_y) override;
+
+ void op_store(const TensorOperand &tensor_op,
+ const TileOperand &tile_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch) override;
+
+ void op_store_dilated(const TensorOperand &tensor_op,
+ const TileOperand &tile_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch,
+ const TileOperand &dilation_x,
+ const TileOperand &dilation_y) override;
+
+ void op_load_indirect(const TileOperand &tile_op,
+ const TensorOperand &tensor_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch) override;
protected:
/** Return a tile view containing a reference to @ref CLTile object and the active area.
@@ -181,9 +213,17 @@ protected:
// For helper functions
private:
/** Helper method to consolidate all load/store logic in this class */
- void op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const TileView<CLTile> &dilation_x, const TileView<CLTile> &dilation_y, bool indirect_buffer);
+ void op_load_store(MemoryOperation op,
+ const TileOperand &tile_op,
+ const TensorOperand &tensor_op,
+ TensorSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &z,
+ const TileOperand &batch,
+ const TileView<CLTile> &dilation_x,
+ const TileView<CLTile> &dilation_y,
+ bool indirect_buffer);
/** This function is the generic function to write both `if` and `else if` blocks.
*
@@ -195,7 +235,11 @@ private:
* @param[in] body The function that writes the body of the else-if block.
* @param[in] is_else_if True if this is an `else if` block, otherwise this is an `if` block.
*/
- void op_if_generic(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if);
+ void op_if_generic(const TileOperand &lhs,
+ BinaryOp op,
+ const TileOperand &rhs,
+ const std::function<void()> &body,
+ bool is_else_if);
// For attributes
private: