diff options
Diffstat (limited to 'compute_kernel_writer/prototype/src/KernelWriter.cpp')
-rw-r--r-- | compute_kernel_writer/prototype/src/KernelWriter.cpp | 96 |
1 files changed, 89 insertions, 7 deletions
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp index 5d79985e87..73458efa1d 100644 --- a/compute_kernel_writer/prototype/src/KernelWriter.cpp +++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp @@ -85,7 +85,7 @@ int32_t KernelWriter::next_id_space() // Tensor and tile declaration // ================================================================================================= -TensorOperand &KernelWriter::create_tensor_argument(const char *name, const TensorInfo &info) +TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info) { const auto var_name = generate_variable_name(name); @@ -97,7 +97,7 @@ TensorOperand &KernelWriter::create_tensor_argument(const char *name, const Tens return *operand; } -TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value) +TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value) { const auto var_name = generate_variable_name(name); @@ -107,7 +107,7 @@ TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value) return *operand; } -std::string KernelWriter::generate_variable_name(const char *name) const +std::string KernelWriter::generate_variable_name(const std::string &name) const { std::stringstream var_name; @@ -181,7 +181,7 @@ void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, cons // Data processing // ================================================================================================= -void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src) +void KernelWriter::op_assign(const TileOperand &dst, const TileOperand &src) { auto impl_dst = dst.create_impl_operand(_impl.get()); auto impl_src = src.create_impl_operand(_impl.get()); @@ -189,7 +189,15 @@ void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src) _impl->op_assign(impl_dst, impl_src); } -void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs, const TileOperand &rhs, BinaryOp op) +void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand &src, const ConvertPolicy policy) +{ + auto impl_dst = dst.create_impl_operand(_impl.get()); + auto impl_src = src.create_impl_operand(_impl.get()); + + _impl->op_cast_expression(impl_dst, impl_src, policy); +} + +void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs) { auto impl_lhs = lhs.create_impl_operand(_impl.get()); auto impl_rhs = rhs.create_impl_operand(_impl.get()); @@ -198,12 +206,81 @@ void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs _impl->op_binary_expression(impl_dst, impl_lhs, op, impl_rhs); } -void KernelWriter::op_scalar_function(TileOperand &dst, const TileOperand &src, ScalarUnaryFunction opcode) +void KernelWriter::op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src) { auto impl_dst = dst.create_impl_operand(_impl.get()); auto impl_src = src.create_impl_operand(_impl.get()); - _impl->op_scalar_function(impl_dst, impl_src, opcode); + _impl->op_unary_expression(impl_dst, op, impl_src); +} + +void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFunction opcode, const TileOperand &src) +{ + auto impl_dst = dst.create_impl_operand(_impl.get()); + auto impl_src = src.create_impl_operand(_impl.get()); + + _impl->op_unary_elementwise_function(impl_dst, opcode, impl_src); +} + +void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second) +{ + auto impl_dst = dst.create_impl_operand(_impl.get()); + auto impl_first = first.create_impl_operand(_impl.get()); + auto impl_second = second.create_impl_operand(_impl.get()); + + _impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second); +} + +void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third) +{ + auto impl_dst = dst.create_impl_operand(_impl.get()); + auto impl_first = first.create_impl_operand(_impl.get()); + auto impl_second = second.create_impl_operand(_impl.get()); + auto impl_third = third.create_impl_operand(_impl.get()); + + _impl->op_ternary_elementwise_function(impl_dst, opcode, impl_first, impl_second, impl_third); +} + +void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) +{ + auto impl_lhs = lhs.create_impl_operand(_impl.get()); + auto impl_rhs = rhs.create_impl_operand(_impl.get()); + + _impl->op_if_header(impl_lhs, op, impl_rhs); + _impl->compound_statement_begin(); + body(); + _impl->compound_statement_end(); +} + +void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) +{ + auto impl_lhs = lhs.create_impl_operand(_impl.get()); + auto impl_rhs = rhs.create_impl_operand(_impl.get()); + + _impl->op_else_if_header(impl_lhs, op, impl_rhs); + _impl->compound_statement_begin(); + body(); + _impl->compound_statement_end(); +} + +void KernelWriter::op_else(const std::function<void()> &body) +{ + _impl->op_else_header(); + _impl->compound_statement_begin(); + body(); + _impl->compound_statement_end(); +} + +void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body) +{ + auto impl_var_name = var_name.create_impl_operand(_impl.get()); + auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get()); + auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get()); + + _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, update_op, impl_update_value_name); + _impl->compound_statement_begin(); + body(); + _impl->compound_statement_end(); } // ================================================================================================= @@ -215,6 +292,11 @@ void KernelWriter::op_get_global_id(TileOperand &dst, int32_t dim) _impl->op_get_global_id(prototype::Operand(dst.name()), dim); } +void KernelWriter::op_return() +{ + _impl->op_return(); +} + // ================================================================================================= // Code generation // ================================================================================================= |