aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/prototype/src/KernelWriter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/prototype/src/KernelWriter.cpp')
-rw-r--r--compute_kernel_writer/prototype/src/KernelWriter.cpp37
1 files changed, 18 insertions, 19 deletions
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 73458efa1d..1ac9ede5b5 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -24,6 +24,7 @@
#include "ckw/KernelWriter.h"
#include "ckw/Error.h"
+#include "ckw/TensorInfo.h"
#include "ckw/TensorOperand.h"
#include "src/Prototype.h"
@@ -85,26 +86,24 @@ int32_t KernelWriter::next_id_space()
// Tensor and tile declaration
// =================================================================================================
-TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
+TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
{
const auto var_name = generate_variable_name(name);
_impl->declare_argument(var_name, create_impl_tensor_info(info));
- auto operand = new TensorOperand(var_name, info);
- register_operand(operand, false);
+ auto &operand = _kernel->register_operand(std::make_unique<TensorOperand>(var_name, info, storage_type));
- return *operand;
+ return operand;
}
TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
{
const auto var_name = generate_variable_name(name);
- auto operand = new TileOperand(var_name, value);
- register_operand(operand, false);
+ auto &operand = _kernel->register_operand(std::make_unique<TileOperand>(var_name, value));
- return *operand;
+ return operand;
}
std::string KernelWriter::generate_variable_name(const std::string &name) const
@@ -116,21 +115,21 @@ std::string KernelWriter::generate_variable_name(const std::string &name) const
return var_name.str();
}
-void KernelWriter::register_operand(OperandBase *operand, bool declaring)
+TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> operand_ptr)
{
- const auto &name = operand->name();
- auto &operands = _kernel->operands();
+ auto &operand = _kernel->register_operand(std::move(operand_ptr));
+ const auto &name = operand.name();
- CKW_ASSERT(operands.find(name) == operands.end());
- operands[name] = std::unique_ptr<OperandBase>(operand);
-
- if(declaring && !operand->is_constant())
+ if(!operand.is_constant())
{
- const auto tile = reinterpret_cast<TileOperand *>(operand);
+ const auto &info = operand.tile_info();
- const auto &info = tile->tile_info();
- _impl->declare_tile(tile->name(), prototype::TileInfo(info.data_type(), info.width(), info.height()));
+ _impl->declare_tile(
+ name,
+ prototype::TileInfo(info.data_type(), info.width(), info.height()));
}
+
+ return operand;
}
// =================================================================================================
@@ -143,7 +142,7 @@ void KernelWriter::op_load(TileOperand &tile, TensorOperand &tensor, const Tenso
tensor.name(),
prototype::GpuSampler{
sampler.format(),
- prototype::GpuSamplerTensorStorage::BufferUint8Ptr,
+ prototype::to_gpu_tensor_storage(tensor.storage_type()),
sampler.address_mode_x(),
sampler.address_mode_y(),
sampler.address_mode_z() });
@@ -164,7 +163,7 @@ void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, cons
tensor.name(),
prototype::GpuSampler{
sampler.format(),
- prototype::GpuSamplerTensorStorage::BufferUint8Ptr,
+ prototype::to_gpu_tensor_storage(tensor.storage_type()),
sampler.address_mode_x(),
sampler.address_mode_y(),
sampler.address_mode_z() });