aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/prototype/src/KernelWriter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/prototype/src/KernelWriter.cpp')
-rw-r--r--compute_kernel_writer/prototype/src/KernelWriter.cpp82
1 files changed, 45 insertions, 37 deletions
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 5c9a16ee33..9f58d9fefa 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -23,9 +23,11 @@
*/
#include "ckw/KernelWriter.h"
+
#include "ckw/Error.h"
#include "ckw/TensorInfo.h"
#include "ckw/TensorOperand.h"
+
#include "src/Prototype.h"
#include <sstream>
@@ -38,7 +40,7 @@ namespace
inline prototype::TensorInfo create_impl_tensor_info(const TensorInfo &info)
{
- return prototype::TensorInfo{ info.shape(), info.data_type(), info.data_layout(), info.id() };
+ return prototype::TensorInfo{info.shape(), info.data_type(), info.data_layout(), info.id()};
}
} // namespace
@@ -86,7 +88,8 @@ int32_t KernelWriter::next_id_space()
// Tensor and tile declaration
// =================================================================================================
-TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
+TensorOperand &
+KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
{
const auto var_name = generate_variable_name(name);
@@ -120,13 +123,11 @@ TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> ope
auto &operand = _kernel->register_operand(std::move(operand_ptr));
const auto &name = operand.name();
- if(!operand.is_constant())
+ if (!operand.is_constant())
{
const auto &info = operand.tile_info();
- _impl->declare_tile(
- name,
- prototype::TileInfo(info.data_type(), info.width(), info.height()));
+ _impl->declare_tile(name, prototype::TileInfo(info.data_type(), info.width(), info.height()));
}
else
{
@@ -140,16 +141,15 @@ TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> ope
// Load and store
// =================================================================================================
-void KernelWriter::op_load(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler, const TileOperand &dilation_y)
+void KernelWriter::op_load(TileOperand &tile,
+ const TensorOperand &tensor,
+ const TensorTileSampler &sampler,
+ const TileOperand &dilation_y)
{
prototype::TensorOperand impl_tensor(
tensor.name(),
- prototype::GpuSampler{
- sampler.format(),
- prototype::to_gpu_tensor_storage(tensor.storage_type()),
- sampler.address_mode_x(),
- sampler.address_mode_y(),
- sampler.address_mode_z() });
+ prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()),
+ sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()});
auto impl_x = sampler.x().create_impl_operand(_impl.get());
auto impl_y = sampler.y().create_impl_operand(_impl.get());
@@ -167,12 +167,8 @@ void KernelWriter::op_load_indirect(TileOperand &tile, const TensorOperand &tens
{
prototype::TensorOperand impl_tensor(
tensor.name(),
- prototype::GpuSampler{
- sampler.format(),
- prototype::to_gpu_tensor_storage(tensor.storage_type()),
- sampler.address_mode_x(),
- sampler.address_mode_y(),
- sampler.address_mode_z() });
+ prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()),
+ sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()});
auto impl_x = sampler.x().create_impl_operand(_impl.get());
auto impl_y = sampler.y().create_impl_operand(_impl.get());
@@ -194,12 +190,8 @@ void KernelWriter::util_get_indirect_buffer(TileOperand &tile,
{
prototype::TensorOperand impl_tensor(
tensor.name(),
- prototype::GpuSampler{
- sampler.format(),
- prototype::to_gpu_tensor_storage(tensor.storage_type()),
- sampler.address_mode_x(),
- sampler.address_mode_y(),
- sampler.address_mode_z() });
+ prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()),
+ sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()});
auto impl_x = x.create_impl_operand(_impl.get());
auto impl_y = y.create_impl_operand(_impl.get());
@@ -215,12 +207,8 @@ void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, cons
{
prototype::TensorOperand impl_tensor(
tensor.name(),
- prototype::GpuSampler{
- sampler.format(),
- prototype::to_gpu_tensor_storage(tensor.storage_type()),
- sampler.address_mode_x(),
- sampler.address_mode_y(),
- sampler.address_mode_z() });
+ prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()),
+ sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()});
auto impl_src = tile.create_impl_operand(_impl.get());
auto impl_x = sampler.x().create_impl_operand(_impl.get());
auto impl_y = sampler.y().create_impl_operand(_impl.get());
@@ -250,7 +238,10 @@ void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand
_impl->op_cast_expression(impl_dst, impl_src, policy);
}
-void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs)
+void KernelWriter::op_binary_expression(const TileOperand &dst,
+ const TileOperand &lhs,
+ BinaryOp op,
+ const TileOperand &rhs)
{
auto impl_lhs = lhs.create_impl_operand(_impl.get());
auto impl_rhs = rhs.create_impl_operand(_impl.get());
@@ -275,7 +266,10 @@ void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFu
_impl->op_unary_elementwise_function(impl_dst, opcode, impl_src);
}
-void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second)
+void KernelWriter::op_binary_elementwise_function(const TileOperand &dst,
+ BinaryFunction opcode,
+ const TileOperand &first,
+ const TileOperand &second)
{
auto impl_dst = dst.create_impl_operand(_impl.get());
auto impl_first = first.create_impl_operand(_impl.get());
@@ -284,7 +278,11 @@ void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, Binary
_impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second);
}
-void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third)
+void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst,
+ TernaryFunction opcode,
+ const TileOperand &first,
+ const TileOperand &second,
+ const TileOperand &third)
{
auto impl_dst = dst.create_impl_operand(_impl.get());
auto impl_first = first.create_impl_operand(_impl.get());
@@ -305,7 +303,10 @@ void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand
_impl->compound_statement_end();
}
-void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+void KernelWriter::op_else_if(const TileOperand &lhs,
+ BinaryOp op,
+ const TileOperand &rhs,
+ const std::function<void()> &body)
{
auto impl_lhs = lhs.create_impl_operand(_impl.get());
auto impl_rhs = rhs.create_impl_operand(_impl.get());
@@ -324,14 +325,21 @@ void KernelWriter::op_else(const std::function<void()> &body)
_impl->compound_statement_end();
}
-void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body)
+void KernelWriter::op_for_loop(const TileOperand &var_name,
+ BinaryOp cond_op,
+ const TileOperand &cond_value_name,
+ const TileOperand &update_var_name,
+ AssignmentOp update_op,
+ const TileOperand &update_value_name,
+ const std::function<void()> &body)
{
auto impl_var_name = var_name.create_impl_operand(_impl.get());
auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get());
auto impl_update_var_name = update_var_name.create_impl_operand(_impl.get());
auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get());
- _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, impl_update_var_name, update_op, impl_update_value_name);
+ _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, impl_update_var_name, update_op,
+ impl_update_value_name);
_impl->compound_statement_begin();
body();
_impl->compound_statement_end();