Diffstat (limited to 'compute_kernel_writer/prototype/src/Prototype.h')
-rw-r--r--  compute_kernel_writer/prototype/src/Prototype.h  389
1 file changed, 281 insertions, 108 deletions
diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h
index fdb4ab1bab..18f284b2b1 100644
--- a/compute_kernel_writer/prototype/src/Prototype.h
+++ b/compute_kernel_writer/prototype/src/Prototype.h
@@ -31,6 +31,7 @@
#include <chrono>
#include <cmath>
#include <cstdint> // int32_t
+#include <functional>
#include <iostream> // cout (to be removed)
#include <map>
#include <memory>
@@ -41,7 +42,12 @@
#include "ckw/Error.h"
#include "ckw/TensorInfo.h"
-#include "ckw/Types.h"
+#include "ckw/types/ConvertPolicy.h"
+#include "ckw/types/DataType.h"
+#include "ckw/types/Functions.h"
+#include "ckw/types/GpuTargetLanguage.h"
+#include "ckw/types/Operators.h"
+#include "ckw/types/TensorSamplerTypes.h"
namespace ckw
{
@@ -1548,6 +1554,18 @@ inline std::string to_string(AssignmentOp op)
}
}
+inline std::string to_string(UnaryOp op)
+{
+ switch(op)
+ {
+ case UnaryOp::LogicalNot:
+ return "!";
+ default:
+ assert(false);
+ return "";
+ }
+}
+
inline std::string to_string(BinaryOp op)
{
switch(op)
@@ -1576,8 +1594,6 @@ inline std::string to_string(BinaryOp op)
return "&&";
case BinaryOp::LogicalOr:
return "||";
- case BinaryOp::LogicalNot:
- return "!";
default:
assert(false);
return "";
@@ -2407,12 +2423,6 @@ struct GpuKernelWriterAttribute
bool return_tensor_component_by_value{ false };
};
-enum class ConvertPolicy
-{
- Wrap, /**< Wrap around */
- Saturate /**< Saturate */
-};
-
enum class RoundingMode
{
None,
@@ -2445,36 +2455,44 @@ public:
virtual void compound_statement_end() = 0;
// Operations
- virtual void op_get_global_id(const Operand &dst_var, int32_t dim) = 0;
+ virtual void op_get_global_id(const Operand &dst_var, int32_t dim) = 0;
+
+ virtual void op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim) = 0;
- virtual void op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim) = 0;
+ virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor) = 0;
- virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor) = 0;
+ virtual void op_get_global_size(const Operand &dst_var, int32_t dim) = 0;
- virtual void op_get_global_size(const Operand &dst_var, int32_t dim) = 0;
+ virtual void op_unary_expression(const Operand &dst, UnaryOp op, const Operand &src) = 0;
- virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+ virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
- virtual void op_assign(const Operand &dst_name, const Operand &src_name) = 0;
+ virtual void op_assign(const Operand &dst_name, const Operand &src_name) = 0;
- virtual void op_scalar_function(const Operand &dst_name, const Operand &src_name, ScalarUnaryFunction func) = 0;
+ virtual void op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) = 0;
- virtual void op_if(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+ virtual void op_binary_elementwise_function(const Operand &dst_name, BinaryFunction func, const Operand &first_name, const Operand &second_name) = 0;
- virtual void op_for_loop(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0;
+ virtual void op_ternary_elementwise_function(const Operand &dst_name, TernaryFunction func, const Operand &first_name, const Operand &second_name, const Operand &third_name) = 0;
- virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
+ virtual void op_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+
+ virtual void op_else_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+
+ virtual void op_else_header() = 0;
+
+ virtual void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0;
+
+ virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
virtual void op_load_immediate(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32), const Operand &dilation_y = Operand("1", OperandType::ScalarInt32)) = 0;
- virtual void op_store_immediate(const TensorOperand &tensor, const Operand &src, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
+ virtual void op_store_immediate(const TensorOperand &tensor, const Operand &src, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
- virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy) = 0;
+ virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy) = 0;
- virtual void op_return() = 0;
+ virtual void op_return() = 0;
- // virtual void op_else() = 0;
- // virtual void op_elseif() = 0;
// Utils
// It is the process of converting
virtual void util_get_indirect_buffer(const Operand &dst, const TensorOperand &tensor, const Operand &x,
@@ -2929,10 +2947,10 @@ private:
std::string to_ls_buffer_address(const std::string &x, const std::string &y, const std::string &z,
const std::string &b) const
{
- auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
+ auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
assert(tensor_storage == GpuTensorStorage::BufferUint8Ptr);
- const std::string ptr_buf = _mapper.tensor_argument()->storage(tensor_storage);
- const std::string dst_type = get_cl_data_type(_dst->format().dt, 1);
+ const std::string ptr_buf = _mapper.tensor_argument()->storage(tensor_storage);
+ const std::string dst_type = get_cl_data_type(_dst->format().dt, 1);
std::string address;
address += "(__global ";
@@ -3135,7 +3153,6 @@ private:
auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
const std::string image2d_obj = _mapper.tensor_argument()->storage(tensor_storage);
- // const DataType dt = _dst->format().dt;
const std::string post_fix = _dst->format().dt == DataType::Fp32 ? "f" : "h";
switch(type)
@@ -3242,7 +3259,7 @@ public:
};
// This utility method needs to go in utils.h
-inline bool is_tile_scalar(IVectorTile *x)
+inline bool is_tile_scalar(const IVectorTile *x)
{
return x->format().w == 1 && x->format().h == 1;
}
@@ -3415,11 +3432,11 @@ public:
void op_get_global_batch(const Operand &o_dst, const TensorOperand &o_tensor) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *dst = operands.unpack(o_dst);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
auto gpu_sampler = o_tensor.sampler();
GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3450,13 +3467,39 @@ public:
_data->code += ");\n";
}
+ void op_unary_expression(const Operand &dst_name, UnaryOp op, const Operand &src_name) override
+ {
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *src = operands.unpack(src_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
+
+ const int32_t dst_w = dst->format().w;
+ const int32_t dst_h = dst->format().h;
+ const int32_t src_w = src->format().w;
+ const std::string dt = dst->underlying_source_variables()[0].type.str;
+
+ const bool broadcast_src_x = dst_w != 1 && src_w == 1;
+
+ const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
+
+ // Broadcasting on Y is automatic
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ _data->code += dst->vector(y).str;
+ _data->code += " = ";
+ _data->code += to_string(op);
+ _data->code += src_prefix + src->vector(y).str;
+ _data->code += ";\n";
+ }
+ }
+
void op_binary_expression(const Operand &dst_name, const Operand &lhs_name, BinaryOp op,
const Operand &rhs_name) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto lhs = operands.unpack(lhs_name);
- auto rhs = operands.unpack(rhs_name);
- auto dst = operands.unpack(dst_name);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *lhs = operands.unpack(lhs_name);
+ const IVectorTile *rhs = operands.unpack(rhs_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
const int32_t dst_w = dst->format().w;
const int32_t dst_h = dst->format().h;
@@ -3488,12 +3531,12 @@ public:
return;
}
- bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
- bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
+ const bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
+ const bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
- std::string lhs_prefix = broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
- std::string rhs_prefix = broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
- std::string op_str = to_string(op);
+ const std::string lhs_prefix = broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
+ const std::string rhs_prefix = broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
+ const std::string op_str = to_string(op);
// Broadcasting on Y is automatic
for(int32_t y = 0; y < dst_h; ++y)
@@ -3511,21 +3554,20 @@ public:
void op_cast_expression(const Operand &o_dst, const Operand &o_src, ConvertPolicy policy) override
{
- CKW_UNUSED(policy);
-
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto src = operands.unpack(o_src);
- auto dst = operands.unpack(o_dst);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *src = operands.unpack(o_src);
+ const IVectorTile *dst = operands.unpack(o_dst);
// const int32_t dst_w = dst->format().w;
const int32_t dst_h = dst->format().h;
- const std::string dt = dst->scalar(0, 0).type.str;
+ const std::string dt = dst->underlying_source_variables()[0].type.str;
+ const std::string sat = (policy == ConvertPolicy::Saturate ? "_sat" : "");
// Broadcasting on Y is automatic
for(int32_t y = 0; y < dst_h; ++y)
{
_data->code += dst->vector(y).str;
- _data->code += " = convert_" + dt + "(";
+ _data->code += " = convert_" + dt + sat + "(";
_data->code += src->vector(y).str;
_data->code += ");\n";
}
@@ -3533,19 +3575,18 @@ public:
void op_assign(const Operand &dst_name, const Operand &src_name) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto src = operands.unpack(src_name);
- auto dst = operands.unpack(dst_name);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *src = operands.unpack(src_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
- const int32_t dst_w = dst->format().w;
- const int32_t dst_h = dst->format().h;
- const int32_t src_w = src->format().w;
- // const int32_t src_h = src->format().h;
- const std::string dt = dst->scalar(0, 0).type.str;
+ const int32_t dst_w = dst->format().w;
+ const int32_t dst_h = dst->format().h;
+ const int32_t src_w = src->format().w;
+ const std::string dt = dst->underlying_source_variables()[0].type.str;
- bool broadcast_src_x = dst_w != 1 && src_w == 1;
+ const bool broadcast_src_x = dst_w != 1 && src_w == 1;
- std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
+ const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
// Broadcasting on Y is automatic
for(int32_t y = 0; y < dst_h; ++y)
@@ -3558,21 +3599,20 @@ public:
}
void
- op_scalar_function(const Operand &dst_name, const Operand &src_name, ScalarUnaryFunction func) override
+ op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto src = operands.unpack(src_name);
- auto dst = operands.unpack(dst_name);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *src = operands.unpack(src_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
- const int32_t dst_w = dst->format().w;
- const int32_t dst_h = dst->format().h;
- const int32_t src_w = src->format().w;
- // const int32_t src_h = src->format().h;
- const std::string dt = dst->scalar(0, 0).type.str;
+ const int32_t dst_w = dst->format().w;
+ const int32_t dst_h = dst->format().h;
+ const int32_t src_w = src->format().w;
+ const std::string dt = dst->underlying_source_variables()[0].type.str;
- bool broadcast_src_x = dst_w != 1 && src_w == 1;
+ const bool broadcast_src_x = dst_w != 1 && src_w == 1;
- std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
+ const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
// Broadcasting on Y is automatic
for(int32_t y = 0; y < dst_h; ++y)
@@ -3582,12 +3622,35 @@ public:
switch(func)
{
- case ScalarUnaryFunction::Exp:
+ case UnaryFunction::Exp:
_data->code += "exp(";
break;
-
+ case UnaryFunction::Tanh:
+ _data->code += "tanh(";
+ break;
+ case UnaryFunction::Sqrt:
+ _data->code += "sqrt(";
+ break;
+ case UnaryFunction::Erf:
+ _data->code += "erf(";
+ break;
+ case UnaryFunction::Fabs:
+ _data->code += "fabs(";
+ break;
+ case UnaryFunction::IsGreaterEqual:
+ _data->code += "isgreaterequal(";
+ break;
+ case UnaryFunction::Log:
+ _data->code += "log(";
+ break;
+ case UnaryFunction::SizeOf:
+ _data->code += "sizeof(";
+ break;
+ case UnaryFunction::Round:
+ _data->code += "round(";
+ break;
default:
- CKW_ASSERT(false);
+ CKW_ASSERT_MSG(false, "Unexpected UnaryFunction used.");
}
_data->code += src_prefix + src->vector(y).str;
@@ -3595,11 +3658,105 @@ public:
}
}
- void op_if(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
+ void op_binary_elementwise_function(const Operand &dst_name, BinaryFunction func, const Operand &first_name, const Operand &second_name) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto lhs = operands.unpack(o_lhs);
- auto rhs = operands.unpack(o_rhs);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *first = operands.unpack(first_name);
+ const IVectorTile *second = operands.unpack(second_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
+
+ const int32_t dst_w = dst->format().w;
+ const int32_t dst_h = dst->format().h;
+ const int32_t first_w = first->format().w;
+ const int32_t second_w = second->format().w;
+ const auto datatype = dst->underlying_source_variables()[0].type;
+ const std::string datatype_str = datatype.str;
+
+ const bool broadcast_first_x = dst_w != 1 && first_w == 1;
+ const bool broadcast_second_x = dst_w != 1 && second_w == 1;
+
+ const std::string first_prefix = broadcast_first_x ? "(" + datatype_str + ")" : "";
+ const std::string second_prefix = broadcast_second_x ? "(" + datatype_str + ")" : "";
+
+ const bool is_float = (datatype.dt == DataType::Fp32 || datatype.dt == DataType::Fp16);
+
+ // Broadcasting on Y is automatic
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ _data->code += dst->vector(y).str;
+ _data->code += " = ";
+
+ switch(func)
+ {
+ case BinaryFunction::Min:
+ _data->code += is_float ? "fmin(" : "min(";
+ break;
+ case BinaryFunction::Max:
+ _data->code += is_float ? "fmax(" : "max(";
+ break;
+ default:
+ CKW_ASSERT_MSG(false, "Unexpected BinaryFunction used.");
+ }
+
+ _data->code += first_prefix + first->vector(y).str;
+ _data->code += ", ";
+ _data->code += second_prefix + second->vector(y).str;
+ _data->code += ");\n";
+ }
+ }
+
+ void op_ternary_elementwise_function(const Operand &dst_name, TernaryFunction func, const Operand &first_name, const Operand &second_name, const Operand &third_name) override
+ {
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *first = operands.unpack(first_name);
+ const IVectorTile *second = operands.unpack(second_name);
+ const IVectorTile *third = operands.unpack(third_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
+
+ const int32_t dst_w = dst->format().w;
+ const int32_t dst_h = dst->format().h;
+ const int32_t first_w = first->format().w;
+ const int32_t second_w = second->format().w;
+ const int32_t third_w = third->format().w;
+ const std::string dt = dst->underlying_source_variables()[0].type.str;
+
+ const bool broadcast_first_x = dst_w != 1 && first_w == 1;
+ const bool broadcast_second_x = dst_w != 1 && second_w == 1;
+ const bool broadcast_third_x = dst_w != 1 && third_w == 1;
+
+ const std::string first_prefix = broadcast_first_x ? "(" + dt + ")" : "";
+ const std::string second_prefix = broadcast_second_x ? "(" + dt + ")" : "";
+ const std::string third_prefix = broadcast_third_x ? "(" + dt + ")" : "";
+
+ // Broadcasting on Y is automatic
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ _data->code += dst->vector(y).str;
+ _data->code += " = ";
+
+ switch(func)
+ {
+ case TernaryFunction::Select:
+ _data->code += "select(";
+ break;
+ default:
+ CKW_ASSERT_MSG(false, "Unexpected TernaryFunction used.");
+ }
+
+ _data->code += first_prefix + first->vector(y).str;
+ _data->code += ", ";
+ _data->code += second_prefix + second->vector(y).str;
+ _data->code += ", ";
+ _data->code += third_prefix + third->vector(y).str;
+ _data->code += ");\n";
+ }
+ }
+
+ void op_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
+ {
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *lhs = operands.unpack(o_lhs);
+ const IVectorTile *rhs = operands.unpack(o_rhs);
assert(is_tile_scalar(lhs));
assert(is_tile_scalar(rhs));
@@ -3613,13 +3770,23 @@ public:
_data->code += ")\n";
}
- void op_for_loop(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value_name,
- AssignmentOp update_op, const Operand &update_value_name) override
+ void op_else_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto var = operands.unpack(var_name);
- auto cond_value = operands.unpack(cond_value_name);
- auto update_value = operands.unpack(update_value_name);
+ _data->code += "else ";
+ op_if_header(o_lhs, op, o_rhs);
+ }
+
+ void op_else_header() override
+ {
+ _data->code += "else\n";
+ }
+
+ void op_for_loop_header(const Operand& var_name, BinaryOp cond_op, const Operand& cond_value_name, AssignmentOp update_op, const Operand& update_value_name) override
+ {
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *var = operands.unpack(var_name);
+ const IVectorTile *cond_value = operands.unpack(cond_value_name);
+ const IVectorTile *update_value = operands.unpack(update_value_name);
const int32_t dst_w = var->format().w;
const int32_t dst_h = var->format().h;
@@ -3646,15 +3813,17 @@ public:
const Operand &dilation_y) override
{
OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
- auto x = operands.unpack(o_x);
- auto y = operands.unpack(o_y);
- auto z = operands.unpack(o_z);
- auto dil_y = operands.unpack(dilation_y);
- auto b = operands.unpack(o_batch_idx);
+
+ // Not const as it requires changes to 'load_writer'.
+ IVectorTile *dst = operands.unpack(o_dst);
+ IVectorTile *x = operands.unpack(o_x);
+ IVectorTile *y = operands.unpack(o_y);
+ IVectorTile *z = operands.unpack(o_z);
+ IVectorTile *dil_y = operands.unpack(dilation_y);
+ IVectorTile *b = operands.unpack(o_batch_idx);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
auto gpu_sampler = o_tensor.sampler();
GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3682,14 +3851,16 @@ public:
const Operand &o_batch_idx) override
{
OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
- auto x = operands.unpack(o_x);
- auto y_ind = operands.unpack(o_indirect_h);
- auto z = operands.unpack(o_z);
- auto b = operands.unpack(o_batch_idx);
+
+ // Not const as it requires changes to 'load_writer'.
+ IVectorTile *dst = operands.unpack(o_dst);
+ IVectorTile *x = operands.unpack(o_x);
+ IVectorTile *y_ind = operands.unpack(o_indirect_h);
+ IVectorTile *z = operands.unpack(o_z);
+ IVectorTile *b = operands.unpack(o_batch_idx);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
auto gpu_sampler = o_tensor.sampler();
GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3712,14 +3883,16 @@ public:
const Operand &batch_index_name) override
{
OperandUnpacker operands(_data->tiles, _data->arguments);
- auto src = operands.unpack(src_name);
- auto x = operands.unpack(x_name);
- auto y = operands.unpack(y_name);
- auto z = operands.unpack(z_name);
- auto b = operands.unpack(batch_index_name);
+
+ // Not const as it requires changes to 'load_writer'.
+ IVectorTile *src = operands.unpack(src_name);
+ IVectorTile *x = operands.unpack(x_name);
+ IVectorTile *y = operands.unpack(y_name);
+ IVectorTile *z = operands.unpack(z_name);
+ IVectorTile *b = operands.unpack(batch_index_name);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(tensor_name);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(tensor_name);
auto gpu_sampler = tensor_name.sampler();
GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3747,15 +3920,15 @@ public:
void util_get_indirect_buffer(const Operand &o_dst, const TensorOperand &o_tensor, const Operand &o_x,
const Operand &o_y, const Operand &o_x_off, const Operand &o_y_off) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
- auto x = operands.unpack(o_x);
- auto y = operands.unpack(o_y);
- auto x_off = operands.unpack(o_x_off);
- auto y_off = operands.unpack(o_y_off);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *dst = operands.unpack(o_dst);
+ const IVectorTile *x = operands.unpack(o_x);
+ const IVectorTile *y = operands.unpack(o_y);
+ const IVectorTile *x_off = operands.unpack(o_x_off);
+ const IVectorTile *y_off = operands.unpack(o_y_off);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
assert(dst->format().w == 1);
assert(x->format().w == 1);