diff options
Diffstat (limited to 'compute_kernel_writer/prototype/src/KernelWriter.cpp')
-rw-r--r-- | compute_kernel_writer/prototype/src/KernelWriter.cpp | 82 |
1 files changed, 45 insertions, 37 deletions
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp index 5c9a16ee33..9f58d9fefa 100644 --- a/compute_kernel_writer/prototype/src/KernelWriter.cpp +++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp @@ -23,9 +23,11 @@ */ #include "ckw/KernelWriter.h" + #include "ckw/Error.h" #include "ckw/TensorInfo.h" #include "ckw/TensorOperand.h" + #include "src/Prototype.h" #include <sstream> @@ -38,7 +40,7 @@ namespace inline prototype::TensorInfo create_impl_tensor_info(const TensorInfo &info) { - return prototype::TensorInfo{ info.shape(), info.data_type(), info.data_layout(), info.id() }; + return prototype::TensorInfo{info.shape(), info.data_type(), info.data_layout(), info.id()}; } } // namespace @@ -86,7 +88,8 @@ int32_t KernelWriter::next_id_space() // Tensor and tile declaration // ================================================================================================= -TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type) +TensorOperand & +KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type) { const auto var_name = generate_variable_name(name); @@ -120,13 +123,11 @@ TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> ope auto &operand = _kernel->register_operand(std::move(operand_ptr)); const auto &name = operand.name(); - if(!operand.is_constant()) + if (!operand.is_constant()) { const auto &info = operand.tile_info(); - _impl->declare_tile( - name, - prototype::TileInfo(info.data_type(), info.width(), info.height())); + _impl->declare_tile(name, prototype::TileInfo(info.data_type(), info.width(), info.height())); } else { @@ -140,16 +141,15 @@ TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> ope // Load and store // ================================================================================================= -void KernelWriter::op_load(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler, const TileOperand &dilation_y) +void KernelWriter::op_load(TileOperand &tile, + const TensorOperand &tensor, + const TensorTileSampler &sampler, + const TileOperand &dilation_y) { prototype::TensorOperand impl_tensor( tensor.name(), - prototype::GpuSampler{ - sampler.format(), - prototype::to_gpu_tensor_storage(tensor.storage_type()), - sampler.address_mode_x(), - sampler.address_mode_y(), - sampler.address_mode_z() }); + prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()), + sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()}); auto impl_x = sampler.x().create_impl_operand(_impl.get()); auto impl_y = sampler.y().create_impl_operand(_impl.get()); @@ -167,12 +167,8 @@ void KernelWriter::op_load_indirect(TileOperand &tile, const TensorOperand &tens { prototype::TensorOperand impl_tensor( tensor.name(), - prototype::GpuSampler{ - sampler.format(), - prototype::to_gpu_tensor_storage(tensor.storage_type()), - sampler.address_mode_x(), - sampler.address_mode_y(), - sampler.address_mode_z() }); + prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()), + sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()}); auto impl_x = sampler.x().create_impl_operand(_impl.get()); auto impl_y = sampler.y().create_impl_operand(_impl.get()); @@ -194,12 +190,8 @@ void KernelWriter::util_get_indirect_buffer(TileOperand &tile, { prototype::TensorOperand impl_tensor( tensor.name(), - prototype::GpuSampler{ - sampler.format(), - prototype::to_gpu_tensor_storage(tensor.storage_type()), - sampler.address_mode_x(), - sampler.address_mode_y(), - sampler.address_mode_z() }); + prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()), + sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()}); auto impl_x = x.create_impl_operand(_impl.get()); auto impl_y = y.create_impl_operand(_impl.get()); @@ -215,12 +207,8 @@ void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, cons { prototype::TensorOperand impl_tensor( tensor.name(), - prototype::GpuSampler{ - sampler.format(), - prototype::to_gpu_tensor_storage(tensor.storage_type()), - sampler.address_mode_x(), - sampler.address_mode_y(), - sampler.address_mode_z() }); + prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()), + sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()}); auto impl_src = tile.create_impl_operand(_impl.get()); auto impl_x = sampler.x().create_impl_operand(_impl.get()); auto impl_y = sampler.y().create_impl_operand(_impl.get()); @@ -250,7 +238,10 @@ void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand _impl->op_cast_expression(impl_dst, impl_src, policy); } -void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs) +void KernelWriter::op_binary_expression(const TileOperand &dst, + const TileOperand &lhs, + BinaryOp op, + const TileOperand &rhs) { auto impl_lhs = lhs.create_impl_operand(_impl.get()); auto impl_rhs = rhs.create_impl_operand(_impl.get()); @@ -275,7 +266,10 @@ void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFu _impl->op_unary_elementwise_function(impl_dst, opcode, impl_src); } -void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second) +void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, + BinaryFunction opcode, + const TileOperand &first, + const TileOperand &second) { auto impl_dst = dst.create_impl_operand(_impl.get()); auto impl_first = first.create_impl_operand(_impl.get()); @@ -284,7 +278,11 @@ void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, Binary _impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second); } -void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third) +void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, + TernaryFunction opcode, + const TileOperand &first, + const TileOperand &second, + const TileOperand &third) { auto impl_dst = dst.create_impl_operand(_impl.get()); auto impl_first = first.create_impl_operand(_impl.get()); @@ -305,7 +303,10 @@ void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand _impl->compound_statement_end(); } -void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) +void KernelWriter::op_else_if(const TileOperand &lhs, + BinaryOp op, + const TileOperand &rhs, + const std::function<void()> &body) { auto impl_lhs = lhs.create_impl_operand(_impl.get()); auto impl_rhs = rhs.create_impl_operand(_impl.get()); @@ -324,14 +325,21 @@ void KernelWriter::op_else(const std::function<void()> &body) _impl->compound_statement_end(); } -void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body) +void KernelWriter::op_for_loop(const TileOperand &var_name, + BinaryOp cond_op, + const TileOperand &cond_value_name, + const TileOperand &update_var_name, + AssignmentOp update_op, + const TileOperand &update_value_name, + const std::function<void()> &body) { auto impl_var_name = var_name.create_impl_operand(_impl.get()); auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get()); auto impl_update_var_name = update_var_name.create_impl_operand(_impl.get()); auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get()); - _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, impl_update_var_name, update_op, impl_update_value_name); + _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, impl_update_var_name, update_op, + impl_update_value_name); _impl->compound_statement_begin(); body(); _impl->compound_statement_end(); |