diff options
Diffstat (limited to 'compute_kernel_writer/prototype/src')
4 files changed, 372 insertions, 117 deletions
diff --git a/compute_kernel_writer/prototype/src/Kernel.cpp b/compute_kernel_writer/prototype/src/Kernel.cpp index bbf5c440a7..692d504887 100644 --- a/compute_kernel_writer/prototype/src/Kernel.cpp +++ b/compute_kernel_writer/prototype/src/Kernel.cpp @@ -23,7 +23,7 @@ */ #include "ckw/Kernel.h" -#include "ckw/Types.h" +#include "ckw/types/GpuTargetLanguage.h" #include "src/Prototype.h" namespace ckw diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp index 5d79985e87..73458efa1d 100644 --- a/compute_kernel_writer/prototype/src/KernelWriter.cpp +++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp @@ -85,7 +85,7 @@ int32_t KernelWriter::next_id_space() // Tensor and tile declaration // ================================================================================================= -TensorOperand &KernelWriter::create_tensor_argument(const char *name, const TensorInfo &info) +TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info) { const auto var_name = generate_variable_name(name); @@ -97,7 +97,7 @@ TensorOperand &KernelWriter::create_tensor_argument(const char *name, const Tens return *operand; } -TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value) +TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value) { const auto var_name = generate_variable_name(name); @@ -107,7 +107,7 @@ TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value) return *operand; } -std::string KernelWriter::generate_variable_name(const char *name) const +std::string KernelWriter::generate_variable_name(const std::string &name) const { std::stringstream var_name; @@ -181,7 +181,7 @@ void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, cons // Data processing // ================================================================================================= -void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src) +void KernelWriter::op_assign(const TileOperand &dst, const TileOperand &src) { auto impl_dst = dst.create_impl_operand(_impl.get()); auto impl_src = src.create_impl_operand(_impl.get()); @@ -189,7 +189,15 @@ void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src) _impl->op_assign(impl_dst, impl_src); } -void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs, const TileOperand &rhs, BinaryOp op) +void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand &src, const ConvertPolicy policy) +{ + auto impl_dst = dst.create_impl_operand(_impl.get()); + auto impl_src = src.create_impl_operand(_impl.get()); + + _impl->op_cast_expression(impl_dst, impl_src, policy); +} + +void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs) { auto impl_lhs = lhs.create_impl_operand(_impl.get()); auto impl_rhs = rhs.create_impl_operand(_impl.get()); @@ -198,12 +206,81 @@ void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs _impl->op_binary_expression(impl_dst, impl_lhs, op, impl_rhs); } -void KernelWriter::op_scalar_function(TileOperand &dst, const TileOperand &src, ScalarUnaryFunction opcode) +void KernelWriter::op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src) { auto impl_dst = dst.create_impl_operand(_impl.get()); auto impl_src = src.create_impl_operand(_impl.get()); - _impl->op_scalar_function(impl_dst, impl_src, opcode); + _impl->op_unary_expression(impl_dst, op, impl_src); +} + +void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFunction opcode, const TileOperand &src) +{ + auto impl_dst = dst.create_impl_operand(_impl.get()); + auto impl_src = src.create_impl_operand(_impl.get()); + + _impl->op_unary_elementwise_function(impl_dst, opcode, impl_src); +} + +void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second) +{ + auto impl_dst = dst.create_impl_operand(_impl.get()); + auto impl_first = first.create_impl_operand(_impl.get()); + auto impl_second = second.create_impl_operand(_impl.get()); + + _impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second); +} + +void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third) +{ + auto impl_dst = dst.create_impl_operand(_impl.get()); + auto impl_first = first.create_impl_operand(_impl.get()); + auto impl_second = second.create_impl_operand(_impl.get()); + auto impl_third = third.create_impl_operand(_impl.get()); + + _impl->op_ternary_elementwise_function(impl_dst, opcode, impl_first, impl_second, impl_third); +} + +void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) +{ + auto impl_lhs = lhs.create_impl_operand(_impl.get()); + auto impl_rhs = rhs.create_impl_operand(_impl.get()); + + _impl->op_if_header(impl_lhs, op, impl_rhs); + _impl->compound_statement_begin(); + body(); + _impl->compound_statement_end(); +} + +void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) +{ + auto impl_lhs = lhs.create_impl_operand(_impl.get()); + auto impl_rhs = rhs.create_impl_operand(_impl.get()); + + _impl->op_else_if_header(impl_lhs, op, impl_rhs); + _impl->compound_statement_begin(); + body(); + _impl->compound_statement_end(); +} + +void KernelWriter::op_else(const std::function<void()> &body) +{ + _impl->op_else_header(); + _impl->compound_statement_begin(); + body(); + _impl->compound_statement_end(); +} + +void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body) +{ + auto impl_var_name = var_name.create_impl_operand(_impl.get()); + auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get()); + auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get()); + + _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, update_op, impl_update_value_name); + _impl->compound_statement_begin(); + body(); + _impl->compound_statement_end(); } // ================================================================================================= @@ -215,6 +292,11 @@ void KernelWriter::op_get_global_id(TileOperand &dst, int32_t dim) _impl->op_get_global_id(prototype::Operand(dst.name()), dim); } +void KernelWriter::op_return() +{ + _impl->op_return(); +} + // ================================================================================================= // Code generation // ================================================================================================= diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h index fdb4ab1bab..18f284b2b1 100644 --- a/compute_kernel_writer/prototype/src/Prototype.h +++ b/compute_kernel_writer/prototype/src/Prototype.h @@ -31,6 +31,7 @@ #include <chrono> #include <cmath> #include <cstdint> // int32_t +#include <functional> #include <iostream> // cout (to be removed) #include <map> #include <memory> @@ -41,7 +42,12 @@ #include "ckw/Error.h" #include "ckw/TensorInfo.h" -#include "ckw/Types.h" +#include "ckw/types/ConvertPolicy.h" +#include "ckw/types/DataType.h" +#include "ckw/types/Functions.h" +#include "ckw/types/GpuTargetLanguage.h" +#include "ckw/types/Operators.h" +#include "ckw/types/TensorSamplerTypes.h" namespace ckw { @@ -1548,6 +1554,18 @@ inline std::string to_string(AssignmentOp op) } } +inline std::string to_string(UnaryOp op) +{ + switch(op) + { + case UnaryOp::LogicalNot: + return "!"; + default: + assert(false); + return ""; + } +} + inline std::string to_string(BinaryOp op) { switch(op) @@ -1576,8 +1594,6 @@ inline std::string to_string(BinaryOp op) return "&&"; case BinaryOp::LogicalOr: return "||"; - case BinaryOp::LogicalNot: - return "!"; default: assert(false); return ""; @@ -2407,12 +2423,6 @@ struct GpuKernelWriterAttribute bool return_tensor_component_by_value{ false }; }; -enum class ConvertPolicy -{ - Wrap, /**< Wrap around */ - Saturate /**< Saturate */ -}; - enum class RoundingMode { None, @@ -2445,36 +2455,44 @@ public: virtual void compound_statement_end() = 0; // Operations - virtual void op_get_global_id(const Operand &dst_var, int32_t dim) = 0; + virtual void op_get_global_id(const Operand &dst_var, int32_t dim) = 0; + + virtual void op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim) = 0; - virtual void op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim) = 0; + virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor) = 0; - virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor) = 0; + virtual void op_get_global_size(const Operand &dst_var, int32_t dim) = 0; - virtual void op_get_global_size(const Operand &dst_var, int32_t dim) = 0; + virtual void op_unary_expression(const Operand &dst, UnaryOp op, const Operand &src) = 0; - virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs) = 0; + virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs) = 0; - virtual void op_assign(const Operand &dst_name, const Operand &src_name) = 0; + virtual void op_assign(const Operand &dst_name, const Operand &src_name) = 0; - virtual void op_scalar_function(const Operand &dst_name, const Operand &src_name, ScalarUnaryFunction func) = 0; + virtual void op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) = 0; - virtual void op_if(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0; + virtual void op_binary_elementwise_function(const Operand &dst_name, BinaryFunction func, const Operand &first_name, const Operand &second_name) = 0; - virtual void op_for_loop(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0; + virtual void op_ternary_elementwise_function(const Operand &dst_name, TernaryFunction func, const Operand &first_name, const Operand &second_name, const Operand &third_name) = 0; - virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0; + virtual void op_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0; + + virtual void op_else_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0; + + virtual void op_else_header() = 0; + + virtual void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0; + + virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0; virtual void op_load_immediate(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32), const Operand &dilation_y = Operand("1", OperandType::ScalarInt32)) = 0; - virtual void op_store_immediate(const TensorOperand &tensor, const Operand &src, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0; + virtual void op_store_immediate(const TensorOperand &tensor, const Operand &src, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0; - virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy) = 0; + virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy) = 0; - virtual void op_return() = 0; + virtual void op_return() = 0; - // virtual void op_else() = 0; - // virtual void op_elseif() = 0; // Utils // It is the process of converting virtual void util_get_indirect_buffer(const Operand &dst, const TensorOperand &tensor, const Operand &x, @@ -2929,10 +2947,10 @@ private: std::string to_ls_buffer_address(const std::string &x, const std::string &y, const std::string &z, const std::string &b) const { - auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage); + auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage); assert(tensor_storage == GpuTensorStorage::BufferUint8Ptr); - const std::string ptr_buf = _mapper.tensor_argument()->storage(tensor_storage); - const std::string dst_type = get_cl_data_type(_dst->format().dt, 1); + const std::string ptr_buf = _mapper.tensor_argument()->storage(tensor_storage); + const std::string dst_type = get_cl_data_type(_dst->format().dt, 1); std::string address; address += "(__global "; @@ -3135,7 +3153,6 @@ private: auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage); const std::string image2d_obj = _mapper.tensor_argument()->storage(tensor_storage); - // const DataType dt = _dst->format().dt; const std::string post_fix = _dst->format().dt == DataType::Fp32 ? "f" : "h"; switch(type) @@ -3242,7 +3259,7 @@ public: }; // This utility method needs to go in utils.h -inline bool is_tile_scalar(IVectorTile *x) +inline bool is_tile_scalar(const IVectorTile *x) { return x->format().w == 1 && x->format().h == 1; } @@ -3415,11 +3432,11 @@ public: void op_get_global_batch(const Operand &o_dst, const TensorOperand &o_tensor) override { - OperandUnpacker operands(_data->tiles, _data->arguments); - auto dst = operands.unpack(o_dst); + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *dst = operands.unpack(o_dst); TensorOperandUnpacker tensor_operands(_data->arguments); - auto tensor = tensor_operands.unpack(o_tensor); + IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor); auto gpu_sampler = o_tensor.sampler(); GpuTensor3dMapper mapper(tensor, gpu_sampler); @@ -3450,13 +3467,39 @@ public: _data->code += ");\n"; } + void op_unary_expression(const Operand &dst_name, UnaryOp op, const Operand &src_name) override + { + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *src = operands.unpack(src_name); + const IVectorTile *dst = operands.unpack(dst_name); + + const int32_t dst_w = dst->format().w; + const int32_t dst_h = dst->format().h; + const int32_t src_w = src->format().w; + const std::string dt = dst->underlying_source_variables()[0].type.str; + + const bool broadcast_src_x = dst_w != 1 && src_w == 1; + + const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : ""; + + // Broadcasting on Y is automatic + for(int32_t y = 0; y < dst_h; ++y) + { + _data->code += dst->vector(y).str; + _data->code += " = "; + _data->code += to_string(op); + _data->code += src_prefix + src->vector(y).str; + _data->code += ";\n"; + } + } + void op_binary_expression(const Operand &dst_name, const Operand &lhs_name, BinaryOp op, const Operand &rhs_name) override { - OperandUnpacker operands(_data->tiles, _data->arguments); - auto lhs = operands.unpack(lhs_name); - auto rhs = operands.unpack(rhs_name); - auto dst = operands.unpack(dst_name); + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *lhs = operands.unpack(lhs_name); + const IVectorTile *rhs = operands.unpack(rhs_name); + const IVectorTile *dst = operands.unpack(dst_name); const int32_t dst_w = dst->format().w; const int32_t dst_h = dst->format().h; @@ -3488,12 +3531,12 @@ public: return; } - bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1; - bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1; + const bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1; + const bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1; - std::string lhs_prefix = broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : ""; - std::string rhs_prefix = broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : ""; - std::string op_str = to_string(op); + const std::string lhs_prefix = broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : ""; + const std::string rhs_prefix = broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : ""; + const std::string op_str = to_string(op); // Broadcasting on Y is automatic for(int32_t y = 0; y < dst_h; ++y) @@ -3511,21 +3554,20 @@ public: void op_cast_expression(const Operand &o_dst, const Operand &o_src, ConvertPolicy policy) override { - CKW_UNUSED(policy); - - OperandUnpacker operands(_data->tiles, _data->arguments); - auto src = operands.unpack(o_src); - auto dst = operands.unpack(o_dst); + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *src = operands.unpack(o_src); + const IVectorTile *dst = operands.unpack(o_dst); // const int32_t dst_w = dst->format().w; const int32_t dst_h = dst->format().h; - const std::string dt = dst->scalar(0, 0).type.str; + const std::string dt = dst->underlying_source_variables()[0].type.str; + const std::string sat = (policy == ConvertPolicy::Saturate ? "_sat" : ""); // Broadcasting on Y is automatic for(int32_t y = 0; y < dst_h; ++y) { _data->code += dst->vector(y).str; - _data->code += " = convert_" + dt + "("; + _data->code += " = convert_" + dt + sat + "("; _data->code += src->vector(y).str; _data->code += ");\n"; } @@ -3533,19 +3575,18 @@ public: void op_assign(const Operand &dst_name, const Operand &src_name) override { - OperandUnpacker operands(_data->tiles, _data->arguments); - auto src = operands.unpack(src_name); - auto dst = operands.unpack(dst_name); + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *src = operands.unpack(src_name); + const IVectorTile *dst = operands.unpack(dst_name); - const int32_t dst_w = dst->format().w; - const int32_t dst_h = dst->format().h; - const int32_t src_w = src->format().w; - // const int32_t src_h = src->format().h; - const std::string dt = dst->scalar(0, 0).type.str; + const int32_t dst_w = dst->format().w; + const int32_t dst_h = dst->format().h; + const int32_t src_w = src->format().w; + const std::string dt = dst->underlying_source_variables()[0].type.str; - bool broadcast_src_x = dst_w != 1 && src_w == 1; + const bool broadcast_src_x = dst_w != 1 && src_w == 1; - std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : ""; + const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : ""; // Broadcasting on Y is automatic for(int32_t y = 0; y < dst_h; ++y) @@ -3558,21 +3599,20 @@ public: } void - op_scalar_function(const Operand &dst_name, const Operand &src_name, ScalarUnaryFunction func) override + op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) override { - OperandUnpacker operands(_data->tiles, _data->arguments); - auto src = operands.unpack(src_name); - auto dst = operands.unpack(dst_name); + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *src = operands.unpack(src_name); + const IVectorTile *dst = operands.unpack(dst_name); - const int32_t dst_w = dst->format().w; - const int32_t dst_h = dst->format().h; - const int32_t src_w = src->format().w; - // const int32_t src_h = src->format().h; - const std::string dt = dst->scalar(0, 0).type.str; + const int32_t dst_w = dst->format().w; + const int32_t dst_h = dst->format().h; + const int32_t src_w = src->format().w; + const std::string dt = dst->underlying_source_variables()[0].type.str; - bool broadcast_src_x = dst_w != 1 && src_w == 1; + const bool broadcast_src_x = dst_w != 1 && src_w == 1; - std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : ""; + const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : ""; // Broadcasting on Y is automatic for(int32_t y = 0; y < dst_h; ++y) @@ -3582,12 +3622,35 @@ public: switch(func) { - case ScalarUnaryFunction::Exp: + case UnaryFunction::Exp: _data->code += "exp("; break; - + case UnaryFunction::Tanh: + _data->code += "tanh("; + break; + case UnaryFunction::Sqrt: + _data->code += "sqrt("; + break; + case UnaryFunction::Erf: + _data->code += "erf("; + break; + case UnaryFunction::Fabs: + _data->code += "fabs("; + break; + case UnaryFunction::IsGreaterEqual: + _data->code += "isgreaterequal("; + break; + case UnaryFunction::Log: + _data->code += "log("; + break; + case UnaryFunction::SizeOf: + _data->code += "sizeof("; + break; + case UnaryFunction::Round: + _data->code += "round("; + break; default: - CKW_ASSERT(false); + CKW_ASSERT_MSG(false, "Unexpected UnaryFunction used."); } _data->code += src_prefix + src->vector(y).str; @@ -3595,11 +3658,105 @@ public: } } - void op_if(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override + void op_binary_elementwise_function(const Operand &dst_name, BinaryFunction func, const Operand &first_name, const Operand &second_name) override { - OperandUnpacker operands(_data->tiles, _data->arguments); - auto lhs = operands.unpack(o_lhs); - auto rhs = operands.unpack(o_rhs); + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *first = operands.unpack(first_name); + const IVectorTile *second = operands.unpack(second_name); + const IVectorTile *dst = operands.unpack(dst_name); + + const int32_t dst_w = dst->format().w; + const int32_t dst_h = dst->format().h; + const int32_t first_w = first->format().w; + const int32_t second_w = second->format().w; + const auto datatype = dst->underlying_source_variables()[0].type; + const std::string datatype_str = datatype.str; + + const bool broadcast_first_x = dst_w != 1 && first_w == 1; + const bool broadcast_second_x = dst_w != 1 && second_w == 1; + + const std::string first_prefix = broadcast_first_x ? "(" + datatype_str + ")" : ""; + const std::string second_prefix = broadcast_second_x ? "(" + datatype_str + ")" : ""; + + const bool is_float = (datatype.dt == DataType::Fp32 || datatype.dt == DataType::Fp16); + + // Broadcasting on Y is automatic + for(int32_t y = 0; y < dst_h; ++y) + { + _data->code += dst->vector(y).str; + _data->code += " = "; + + switch(func) + { + case BinaryFunction::Min: + _data->code += is_float ? "fmin(" : "min("; + break; + case BinaryFunction::Max: + _data->code += is_float ? "fmax(" : "max("; + break; + default: + CKW_ASSERT_MSG(false, "Unexpected BinaryFunction used."); + } + + _data->code += first_prefix + first->vector(y).str; + _data->code += ", "; + _data->code += second_prefix + second->vector(y).str; + _data->code += ");\n"; + } + } + + void op_ternary_elementwise_function(const Operand &dst_name, TernaryFunction func, const Operand &first_name, const Operand &second_name, const Operand &third_name) override + { + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *first = operands.unpack(first_name); + const IVectorTile *second = operands.unpack(second_name); + const IVectorTile *third = operands.unpack(third_name); + const IVectorTile *dst = operands.unpack(dst_name); + + const int32_t dst_w = dst->format().w; + const int32_t dst_h = dst->format().h; + const int32_t first_w = first->format().w; + const int32_t second_w = second->format().w; + const int32_t third_w = third->format().w; + const std::string dt = dst->underlying_source_variables()[0].type.str; + + const bool broadcast_first_x = dst_w != 1 && first_w == 1; + const bool broadcast_second_x = dst_w != 1 && second_w == 1; + const bool broadcast_third_x = dst_w != 1 && third_w == 1; + + const std::string first_prefix = broadcast_first_x ? "(" + dt + ")" : ""; + const std::string second_prefix = broadcast_second_x ? "(" + dt + ")" : ""; + const std::string third_prefix = broadcast_third_x ? "(" + dt + ")" : ""; + + // Broadcasting on Y is automatic + for(int32_t y = 0; y < dst_h; ++y) + { + _data->code += dst->vector(y).str; + _data->code += " = "; + + switch(func) + { + case TernaryFunction::Select: + _data->code += "select("; + break; + default: + CKW_ASSERT_MSG(false, "Unexpected TernaryFunction used."); + } + + _data->code += first_prefix + first->vector(y).str; + _data->code += ", "; + _data->code += second_prefix + second->vector(y).str; + _data->code += ", "; + _data->code += third_prefix + third->vector(y).str; + _data->code += ");\n"; + } + } + + void op_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override + { + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *lhs = operands.unpack(o_lhs); + const IVectorTile *rhs = operands.unpack(o_rhs); assert(is_tile_scalar(lhs)); assert(is_tile_scalar(rhs)); @@ -3613,13 +3770,23 @@ public: _data->code += ")\n"; } - void op_for_loop(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value_name, - AssignmentOp update_op, const Operand &update_value_name) override + void op_else_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override { - OperandUnpacker operands(_data->tiles, _data->arguments); - auto var = operands.unpack(var_name); - auto cond_value = operands.unpack(cond_value_name); - auto update_value = operands.unpack(update_value_name); + _data->code += "else "; + op_if_header(o_lhs, op, o_rhs); + } + + void op_else_header() override + { + _data->code += "else\n"; + } + + void op_for_loop_header(const Operand& var_name, BinaryOp cond_op, const Operand& cond_value_name, AssignmentOp update_op, const Operand& update_value_name) override + { + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *var = operands.unpack(var_name); + const IVectorTile *cond_value = operands.unpack(cond_value_name); + const IVectorTile *update_value = operands.unpack(update_value_name); const int32_t dst_w = var->format().w; const int32_t dst_h = var->format().h; @@ -3646,15 +3813,17 @@ public: const Operand &dilation_y) override { OperandUnpacker operands(_data->tiles, _data->arguments); - auto dst = operands.unpack(o_dst); - auto x = operands.unpack(o_x); - auto y = operands.unpack(o_y); - auto z = operands.unpack(o_z); - auto dil_y = operands.unpack(dilation_y); - auto b = operands.unpack(o_batch_idx); + + // Not const as it requires changes to 'load_writer'. + IVectorTile *dst = operands.unpack(o_dst); + IVectorTile *x = operands.unpack(o_x); + IVectorTile *y = operands.unpack(o_y); + IVectorTile *z = operands.unpack(o_z); + IVectorTile *dil_y = operands.unpack(dilation_y); + IVectorTile *b = operands.unpack(o_batch_idx); TensorOperandUnpacker tensor_operands(_data->arguments); - auto tensor = tensor_operands.unpack(o_tensor); + IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor); auto gpu_sampler = o_tensor.sampler(); GpuTensor3dMapper mapper(tensor, gpu_sampler); @@ -3682,14 +3851,16 @@ public: const Operand &o_batch_idx) override { OperandUnpacker operands(_data->tiles, _data->arguments); - auto dst = operands.unpack(o_dst); - auto x = operands.unpack(o_x); - auto y_ind = operands.unpack(o_indirect_h); - auto z = operands.unpack(o_z); - auto b = operands.unpack(o_batch_idx); + + // Not const as it requires changes to 'load_writer'. + IVectorTile *dst = operands.unpack(o_dst); + IVectorTile *x = operands.unpack(o_x); + IVectorTile *y_ind = operands.unpack(o_indirect_h); + IVectorTile *z = operands.unpack(o_z); + IVectorTile *b = operands.unpack(o_batch_idx); TensorOperandUnpacker tensor_operands(_data->arguments); - auto tensor = tensor_operands.unpack(o_tensor); + IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor); auto gpu_sampler = o_tensor.sampler(); GpuTensor3dMapper mapper(tensor, gpu_sampler); @@ -3712,14 +3883,16 @@ public: const Operand &batch_index_name) override { OperandUnpacker operands(_data->tiles, _data->arguments); - auto src = operands.unpack(src_name); - auto x = operands.unpack(x_name); - auto y = operands.unpack(y_name); - auto z = operands.unpack(z_name); - auto b = operands.unpack(batch_index_name); + + // Not const as it requires changes to 'load_writer'. + IVectorTile *src = operands.unpack(src_name); + IVectorTile *x = operands.unpack(x_name); + IVectorTile *y = operands.unpack(y_name); + IVectorTile *z = operands.unpack(z_name); + IVectorTile *b = operands.unpack(batch_index_name); TensorOperandUnpacker tensor_operands(_data->arguments); - auto tensor = tensor_operands.unpack(tensor_name); + IGpuTensorArgument *tensor = tensor_operands.unpack(tensor_name); auto gpu_sampler = tensor_name.sampler(); GpuTensor3dMapper mapper(tensor, gpu_sampler); @@ -3747,15 +3920,15 @@ public: void util_get_indirect_buffer(const Operand &o_dst, const TensorOperand &o_tensor, const Operand &o_x, const Operand &o_y, const Operand &o_x_off, const Operand &o_y_off) override { - OperandUnpacker operands(_data->tiles, _data->arguments); - auto dst = operands.unpack(o_dst); - auto x = operands.unpack(o_x); - auto y = operands.unpack(o_y); - auto x_off = operands.unpack(o_x_off); - auto y_off = operands.unpack(o_y_off); + OperandUnpacker operands(_data->tiles, _data->arguments); + const IVectorTile *dst = operands.unpack(o_dst); + const IVectorTile *x = operands.unpack(o_x); + const IVectorTile *y = operands.unpack(o_y); + const IVectorTile *x_off = operands.unpack(o_x_off); + const IVectorTile *y_off = operands.unpack(o_y_off); TensorOperandUnpacker tensor_operands(_data->arguments); - auto tensor = tensor_operands.unpack(o_tensor); + IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor); assert(dst->format().w == 1); assert(x->format().w == 1); diff --git a/compute_kernel_writer/prototype/src/TensorTileSampler.cpp b/compute_kernel_writer/prototype/src/TensorTileSampler.cpp index 143d550dec..28e54df3a5 100644 --- a/compute_kernel_writer/prototype/src/TensorTileSampler.cpp +++ b/compute_kernel_writer/prototype/src/TensorTileSampler.cpp @@ -24,7 +24,7 @@ #include "ckw/TensorTileSampler.h" #include "ckw/TileOperand.h" -#include "ckw/Types.h" +#include "ckw/types/TensorSamplerTypes.h" namespace ckw { |