aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/prototype/src
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/prototype/src')
-rw-r--r--compute_kernel_writer/prototype/src/KernelWriter.cpp58
-rw-r--r--compute_kernel_writer/prototype/src/Prototype.h18
-rw-r--r--compute_kernel_writer/prototype/src/TileOperand.cpp57
3 files changed, 109 insertions, 24 deletions
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 9122e518b4..f29cf12802 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -128,6 +128,10 @@ TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> ope
name,
prototype::TileInfo(info.data_type(), info.width(), info.height()));
}
+ else
+ {
+ _impl->declare_const_tile(name, operand.value(), operand.data_type());
+ }
return operand;
}
@@ -136,7 +140,7 @@ TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> ope
// Load and store
// =================================================================================================
-void KernelWriter::op_load(TileOperand &tile, TensorOperand &tensor, const TensorTileSampler &sampler)
+void KernelWriter::op_load(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler, const TileOperand &dilation_y)
{
prototype::TensorOperand impl_tensor(
tensor.name(),
@@ -152,9 +156,59 @@ void KernelWriter::op_load(TileOperand &tile, TensorOperand &tensor, const Tenso
auto impl_z = sampler.z().create_impl_operand(_impl.get());
auto impl_b = sampler.b().create_impl_operand(_impl.get());
+ auto impl_dilation_y = dilation_y.create_impl_operand(_impl.get());
+
+ auto impl_dst = tile.create_impl_operand(_impl.get());
+
+ _impl->op_load_immediate(impl_tensor, impl_dst, impl_x, impl_y, impl_z, impl_b, impl_dilation_y);
+}
+
+void KernelWriter::op_load_indirect(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler)
+{
+ prototype::TensorOperand impl_tensor(
+ tensor.name(),
+ prototype::GpuSampler{
+ sampler.format(),
+ prototype::to_gpu_tensor_storage(tensor.storage_type()),
+ sampler.address_mode_x(),
+ sampler.address_mode_y(),
+ sampler.address_mode_z() });
+
+ auto impl_x = sampler.x().create_impl_operand(_impl.get());
+ auto impl_y = sampler.y().create_impl_operand(_impl.get());
+ auto impl_z = sampler.z().create_impl_operand(_impl.get());
+ auto impl_b = sampler.b().create_impl_operand(_impl.get());
+
+ auto impl_dst = tile.create_impl_operand(_impl.get());
+
+ _impl->op_load_indirect(impl_tensor, impl_dst, impl_x, impl_y, impl_z, impl_b);
+}
+
+void KernelWriter::util_get_indirect_buffer(TileOperand &tile,
+ const TensorOperand &tensor,
+ const TensorTileSampler &sampler,
+ const TileOperand &x,
+ const TileOperand &y,
+ const TileOperand &x_off,
+ const TileOperand &y_off)
+{
+ prototype::TensorOperand impl_tensor(
+ tensor.name(),
+ prototype::GpuSampler{
+ sampler.format(),
+ prototype::to_gpu_tensor_storage(tensor.storage_type()),
+ sampler.address_mode_x(),
+ sampler.address_mode_y(),
+ sampler.address_mode_z() });
+
+ auto impl_x = x.create_impl_operand(_impl.get());
+ auto impl_y = y.create_impl_operand(_impl.get());
+ auto impl_x_off = x_off.create_impl_operand(_impl.get());
+ auto impl_y_off = y_off.create_impl_operand(_impl.get());
+
auto impl_dst = tile.create_impl_operand(_impl.get());
- _impl->op_load_immediate(impl_tensor, impl_dst, impl_x, impl_y, impl_z, impl_b);
+ _impl->util_get_indirect_buffer(impl_dst, impl_tensor, impl_x, impl_y, impl_x_off, impl_y_off);
}
void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, const TensorTileSampler &sampler)
diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h
index a8dc7fbfdb..2b519471ac 100644
--- a/compute_kernel_writer/prototype/src/Prototype.h
+++ b/compute_kernel_writer/prototype/src/Prototype.h
@@ -3009,7 +3009,7 @@ private:
address += " + (";
address += x + ") * sizeof(" + dst_type + ")";
}
- if(y != "0" && (_mapper.is_one_component_y() != true))
+ if(y != "0")
{
const std::string stride_y = _mapper.tensor_component_stride_y();
address += " + (";
@@ -3249,7 +3249,7 @@ private:
std::string coord_x = "(" + x + ") >> 2";
std::string coord_y = "(";
- if(y != "0" && (_mapper.is_one_component_y() != true))
+ if(y != "0")
{
coord_y += y;
}
@@ -4024,13 +4024,6 @@ public:
_data->code += ", ";
_data->code += x_s->scalar(0, i).str;
_data->code += " >= 0);\n";
- // mi_0 = select(wxh, mi_0, y_s >= 0);
- _data->code += dst->scalar(0, i).str;
- _data->code += " = select(-1, ";
- _data->code += dst->scalar(0, i).str;
- _data->code += ", ";
- _data->code += y_s->scalar(0, i).str;
- _data->code += " >= 0);\n";
// mi_0 = select(wxh, mi_0, x_s < width);
_data->code += dst->scalar(0, i).str;
_data->code += " = select(-1, ";
@@ -4039,6 +4032,13 @@ public:
_data->code += x_s->scalar(0, i).str;
_data->code += " < ";
_data->code += width + ");\n";
+ // mi_0 = select(wxh, mi_0, y_s >= 0);
+ _data->code += dst->scalar(0, i).str;
+ _data->code += " = select(-1, ";
+ _data->code += dst->scalar(0, i).str;
+ _data->code += ", ";
+ _data->code += y_s->scalar(0, i).str;
+ _data->code += " >= 0);\n";
// mi_0 = select(wxh, mi_0, y_s < height);
_data->code += dst->scalar(0, i).str;
_data->code += " = select(-1, ";
diff --git a/compute_kernel_writer/prototype/src/TileOperand.cpp b/compute_kernel_writer/prototype/src/TileOperand.cpp
index fcb3cb6415..bf6a15b9df 100644
--- a/compute_kernel_writer/prototype/src/TileOperand.cpp
+++ b/compute_kernel_writer/prototype/src/TileOperand.cpp
@@ -30,22 +30,42 @@ namespace ckw
{
TileOperand::TileOperand(const std::string &name, const TileInfo &info)
- : OperandBase(name), _info(info), _value{ 0 }, _constant(false)
+ : OperandBase(name),
+ _info(info),
+ _value{ std::vector<std::string>{ "0" } },
+ _constant(false)
{
}
TileOperand::TileOperand(const std::string &name, DataType data_type)
- : OperandBase(name), _info(TileInfo{ data_type }), _value(0), _constant(false)
+ : OperandBase(name),
+ _info(TileInfo{ data_type }),
+ _value{ std::vector<std::string>{ "0" } },
+ _constant(false)
{
}
TileOperand::TileOperand(const std::string &name, int32_t value)
- : OperandBase(name), _info(TileInfo{ DataType::Int32 }), _value(value), _constant(true)
+ : OperandBase(name),
+ _info(TileInfo{ DataType::Int32 }),
+ _value{ std::vector<std::string>{ std::to_string(value) } },
+ _constant(true)
{
}
TileOperand::TileOperand(const std::string &name, float value)
- : OperandBase(name), _info(TileInfo{ DataType::Fp32 }), _value(value), _constant(true)
+ : OperandBase(name),
+ _info(TileInfo{ DataType::Fp32 }),
+ _value{ std::vector<std::string>{ std::to_string(value) } },
+ _constant(true)
+{
+}
+
+TileOperand::TileOperand(const std::string &name, const TileContainer &vals, DataType dt)
+ : OperandBase(name),
+ _info(TileInfo{ dt, static_cast<int32_t>(vals.size()), static_cast<int32_t>(vals[0].size()) }),
+ _value(vals),
+ _constant(true)
{
}
@@ -55,17 +75,23 @@ prototype::Operand TileOperand::create_impl_operand(prototype::IGpuKernelWriter
if(_constant)
{
- switch(_info.data_type())
+ if(is_scalar())
{
- case DataType::Int32:
- return prototype::Operand(std::to_string(_value.get<int32_t>()),
- prototype::OperandType::ScalarInt32);
+ switch(_info.data_type())
+ {
+ case DataType::Int32:
+ return prototype::Operand(_value[0][0], prototype::OperandType::ScalarInt32);
- case DataType::Fp32:
- return prototype::Operand(std::to_string(_value.get<float>()), prototype::OperandType::ScalarFp32);
+ case DataType::Fp32:
+ return prototype::Operand(_value[0][0], prototype::OperandType::ScalarFp32);
- default:
- CKW_ASSERT(false);
+ default:
+ CKW_ASSERT(false);
+ }
+ }
+ else
+ {
+ return prototype::Operand(name());
}
}
else
@@ -94,11 +120,16 @@ bool TileOperand::is_scalar() const
return _info.width() == 1 && _info.height() == 1;
}
-ScalarValue TileOperand::scalar_value() const
+std::string TileOperand::scalar_value() const
{
CKW_ASSERT(is_scalar());
CKW_ASSERT(is_constant());
+ return _value[0][0];
+}
+
+const TileContainer &TileOperand::value() const
+{
return _value;
}