diff options
Diffstat (limited to 'compute_kernel_writer/include/ckw/KernelWriter.h')
-rw-r--r-- | compute_kernel_writer/include/ckw/KernelWriter.h | 85 |
1 files changed, 59 insertions, 26 deletions
diff --git a/compute_kernel_writer/include/ckw/KernelWriter.h b/compute_kernel_writer/include/ckw/KernelWriter.h index 15c99fe652..0d739e859a 100644 --- a/compute_kernel_writer/include/ckw/KernelWriter.h +++ b/compute_kernel_writer/include/ckw/KernelWriter.h @@ -115,7 +115,8 @@ public: * @param[in] first The first source tile. * @param[in] second The second source tile. */ - virtual void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) = 0; + virtual void + op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) = 0; /** Write ternary expression statement: `<dst> = <op>(<first>, <second>, <third>);`. * @@ -125,7 +126,11 @@ public: * @param[in] second The second source tile. * @param[in] third The third source tile. */ - virtual void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) = 0; + virtual void op_ternary(const TileOperand &dst, + TernaryOp op, + const TileOperand &first, + const TileOperand &second, + const TileOperand &third) = 0; // ============================================================================================= // Flow control @@ -138,7 +143,8 @@ public: * @param[in] rhs The RHS tile of the condition. * @param[in] body The function that writes the body of the if block. */ - virtual void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) = 0; + virtual void + op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) = 0; /** Write else-if block: `else if(<lhs> <op> <rhs>) { <body> }`. * @@ -147,7 +153,8 @@ public: * @param[in] rhs The RHS tile of the condition. * @param[in] body The function that writes the body of the else-if block. */ - virtual void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) = 0; + virtual void + op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) = 0; /** Write an else block: `else { <body> }`. * @@ -165,10 +172,13 @@ public: * @param[in] update_value The value which is updated at every iteration. * @param[in] body The function that writes the body of the for-loop block. */ - virtual void op_for_loop( - const TileOperand &var, BinaryOp cond_op, const TileOperand &cond_value, - const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value, - const std::function<void()> &body) = 0; + virtual void op_for_loop(const TileOperand &var, + BinaryOp cond_op, + const TileOperand &cond_value, + const TileOperand &update_var, + AssignmentOp update_op, + const TileOperand &update_value, + const std::function<void()> &body) = 0; /** Write the return statement. */ virtual void op_return() = 0; @@ -271,9 +281,13 @@ public: * @param[in] z z-coordinate * @param[in] batch batch */ - virtual void op_load( - const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0; + virtual void op_load(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch) = 0; /** Load the data from the tensor memory to the tile in a dilated way using the sampling information. * @@ -282,27 +296,41 @@ public: * @param[in] dilation_x Dilation while reading in x-dimension * @param[in] dilation_y Dilation while reading in y-dimension */ - virtual void op_load_dilated( - const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const TileOperand &dilation_x, const TileOperand &dilation_y) = 0; + virtual void op_load_dilated(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch, + const TileOperand &dilation_x, + const TileOperand &dilation_y) = 0; /** Store the data to the tensor memory from the tile using the sampling information. * * Similar to @ref KernelWriter::op_load() */ - virtual void op_store( - const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0; + virtual void op_store(const TensorOperand &tensor_op, + const TileOperand &tile_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch) = 0; /** Store the data to the tensor memory from the tile in a dilated way using the sampling information. * * Similar to @ref KernelWriter::op_load_dilated() */ - virtual void op_store_dilated( - const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, - const TileOperand &dilation_x, const TileOperand &dilation_y) = 0; + virtual void op_store_dilated(const TensorOperand &tensor_op, + const TileOperand &tile_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch, + const TileOperand &dilation_x, + const TileOperand &dilation_y) = 0; /** Load the data from the tensor memory to the tile using the indirect buffer approach and respecting the sampling information. * @@ -314,8 +342,13 @@ public: * @param[in] z z-coordinate * @param[in] batch batch */ - virtual void op_load_indirect(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, - const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch_op) = 0; + virtual void op_load_indirect(const TileOperand &tile_op, + const TensorOperand &tensor_op, + TensorSampler &sampler, + const TileOperand &x, + const TileOperand &y, + const TileOperand &z, + const TileOperand &batch_op) = 0; protected: // ============================================================================================= @@ -373,8 +406,8 @@ protected: static DataType get_data_type(const ConstantData &data); private: - int32_t _id_space{ 0 }; - int32_t _last_created_id_space{ 0 }; + int32_t _id_space{0}; + int32_t _last_created_id_space{0}; }; } // namespace ckw |