aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/prototype/src/KernelWriter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/prototype/src/KernelWriter.cpp')
-rw-r--r--compute_kernel_writer/prototype/src/KernelWriter.cpp96
1 files changed, 89 insertions, 7 deletions
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 5d79985e87..73458efa1d 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -85,7 +85,7 @@ int32_t KernelWriter::next_id_space()
// Tensor and tile declaration
// =================================================================================================
-TensorOperand &KernelWriter::create_tensor_argument(const char *name, const TensorInfo &info)
+TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
{
const auto var_name = generate_variable_name(name);
@@ -97,7 +97,7 @@ TensorOperand &KernelWriter::create_tensor_argument(const char *name, const Tens
return *operand;
}
-TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value)
+TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
{
const auto var_name = generate_variable_name(name);
@@ -107,7 +107,7 @@ TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value)
return *operand;
}
-std::string KernelWriter::generate_variable_name(const char *name) const
+std::string KernelWriter::generate_variable_name(const std::string &name) const
{
std::stringstream var_name;
@@ -181,7 +181,7 @@ void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, cons
// Data processing
// =================================================================================================
-void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src)
+void KernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
{
auto impl_dst = dst.create_impl_operand(_impl.get());
auto impl_src = src.create_impl_operand(_impl.get());
@@ -189,7 +189,15 @@ void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src)
_impl->op_assign(impl_dst, impl_src);
}
-void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs, const TileOperand &rhs, BinaryOp op)
+void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand &src, const ConvertPolicy policy)
+{
+ auto impl_dst = dst.create_impl_operand(_impl.get());
+ auto impl_src = src.create_impl_operand(_impl.get());
+
+ _impl->op_cast_expression(impl_dst, impl_src, policy);
+}
+
+void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs)
{
auto impl_lhs = lhs.create_impl_operand(_impl.get());
auto impl_rhs = rhs.create_impl_operand(_impl.get());
@@ -198,12 +206,81 @@ void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs
_impl->op_binary_expression(impl_dst, impl_lhs, op, impl_rhs);
}
-void KernelWriter::op_scalar_function(TileOperand &dst, const TileOperand &src, ScalarUnaryFunction opcode)
+void KernelWriter::op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src)
{
auto impl_dst = dst.create_impl_operand(_impl.get());
auto impl_src = src.create_impl_operand(_impl.get());
- _impl->op_scalar_function(impl_dst, impl_src, opcode);
+ _impl->op_unary_expression(impl_dst, op, impl_src);
+}
+
+void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFunction opcode, const TileOperand &src)
+{
+ auto impl_dst = dst.create_impl_operand(_impl.get());
+ auto impl_src = src.create_impl_operand(_impl.get());
+
+ _impl->op_unary_elementwise_function(impl_dst, opcode, impl_src);
+}
+
+void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second)
+{
+ auto impl_dst = dst.create_impl_operand(_impl.get());
+ auto impl_first = first.create_impl_operand(_impl.get());
+ auto impl_second = second.create_impl_operand(_impl.get());
+
+ _impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second);
+}
+
+void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third)
+{
+ auto impl_dst = dst.create_impl_operand(_impl.get());
+ auto impl_first = first.create_impl_operand(_impl.get());
+ auto impl_second = second.create_impl_operand(_impl.get());
+ auto impl_third = third.create_impl_operand(_impl.get());
+
+ _impl->op_ternary_elementwise_function(impl_dst, opcode, impl_first, impl_second, impl_third);
+}
+
+void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+{
+ auto impl_lhs = lhs.create_impl_operand(_impl.get());
+ auto impl_rhs = rhs.create_impl_operand(_impl.get());
+
+ _impl->op_if_header(impl_lhs, op, impl_rhs);
+ _impl->compound_statement_begin();
+ body();
+ _impl->compound_statement_end();
+}
+
+void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+{
+ auto impl_lhs = lhs.create_impl_operand(_impl.get());
+ auto impl_rhs = rhs.create_impl_operand(_impl.get());
+
+ _impl->op_else_if_header(impl_lhs, op, impl_rhs);
+ _impl->compound_statement_begin();
+ body();
+ _impl->compound_statement_end();
+}
+
+void KernelWriter::op_else(const std::function<void()> &body)
+{
+ _impl->op_else_header();
+ _impl->compound_statement_begin();
+ body();
+ _impl->compound_statement_end();
+}
+
+void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body)
+{
+ auto impl_var_name = var_name.create_impl_operand(_impl.get());
+ auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get());
+ auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get());
+
+ _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, update_op, impl_update_value_name);
+ _impl->compound_statement_begin();
+ body();
+ _impl->compound_statement_end();
}
// =================================================================================================
@@ -215,6 +292,11 @@ void KernelWriter::op_get_global_id(TileOperand &dst, int32_t dim)
_impl->op_get_global_id(prototype::Operand(dst.name()), dim);
}
+void KernelWriter::op_return()
+{
+ _impl->op_return();
+}
+
// =================================================================================================
// Code generation
// =================================================================================================