aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/prototype/src/Kernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/prototype/src/Kernel.cpp')
-rw-r--r--compute_kernel_writer/prototype/src/Kernel.cpp99
1 files changed, 94 insertions, 5 deletions
diff --git a/compute_kernel_writer/prototype/src/Kernel.cpp b/compute_kernel_writer/prototype/src/Kernel.cpp
index 692d504887..884b69afc6 100644
--- a/compute_kernel_writer/prototype/src/Kernel.cpp
+++ b/compute_kernel_writer/prototype/src/Kernel.cpp
@@ -23,6 +23,7 @@
*/
#include "ckw/Kernel.h"
+#include "ckw/TensorOperand.h"
#include "ckw/types/GpuTargetLanguage.h"
#include "src/Prototype.h"
@@ -30,7 +31,7 @@ namespace ckw
{
Kernel::Kernel(const char *name, GpuTargetLanguage language)
- : _name(name), _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)), _operands{}
+ : _name(name), _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)), _operands{}, _tensor_id_operands{}
{
}
@@ -43,14 +44,102 @@ const std::string &Kernel::name() const
return _name;
}
-const std::map<std::string, std::unique_ptr<OperandBase>> &Kernel::operands() const
+std::vector<KernelArgument> Kernel::arguments() const
{
- return _operands;
+ std::vector<KernelArgument> arguments;
+
+ const auto impl_args = _kernel->arguments.tensor_argument_declarations();
+
+ for(auto tensor_arg : impl_args)
+ {
+ auto tensor = _tensor_id_operands.at(tensor_arg->format().id);
+ arguments.push_back(*tensor);
+
+ for(auto component_arg : tensor_arg->component_declarations())
+ {
+ switch(component_arg)
+ {
+ case TensorComponentType::OffsetFirstElement:
+ arguments.push_back(tensor->offset_first_element_in_bytes());
+ break;
+
+ case TensorComponentType::Stride1:
+ arguments.push_back(tensor->stride1());
+ break;
+
+ case TensorComponentType::Stride2:
+ arguments.push_back(tensor->stride2());
+ break;
+
+ case TensorComponentType::Stride3:
+ arguments.push_back(tensor->stride3());
+ break;
+
+ case TensorComponentType::Stride4:
+ arguments.push_back(tensor->stride4());
+ break;
+
+ case TensorComponentType::Dim0:
+ arguments.push_back(tensor->dim0());
+ break;
+
+ case TensorComponentType::Dim1:
+ arguments.push_back(tensor->dim1());
+ break;
+
+ case TensorComponentType::Dim2:
+ arguments.push_back(tensor->dim2());
+ break;
+
+ case TensorComponentType::Dim3:
+ arguments.push_back(tensor->dim3());
+ break;
+
+ case TensorComponentType::Dim4:
+ arguments.push_back(tensor->dim4());
+ break;
+
+ case TensorComponentType::Dim1xDim2:
+ arguments.push_back(tensor->dim1_dim2());
+ break;
+
+ case TensorComponentType::Dim1xDim2xDim3:
+ arguments.push_back(tensor->dim1_dim2_dim3());
+ break;
+
+ default:
+ CKW_ASSERT(false);
+ }
+ }
+ }
+
+ return arguments;
+}
+
+TileOperand &Kernel::register_operand(std::unique_ptr<TileOperand> operand)
+{
+ const auto &name = operand->name();
+ auto ptr = operand.get();
+
+ CKW_ASSERT(_operands.find(name) == _operands.end());
+ _operands[name] = std::move(operand);
+
+ return *ptr;
}
-std::map<std::string, std::unique_ptr<OperandBase>> &Kernel::operands()
+TensorOperand &Kernel::register_operand(std::unique_ptr<TensorOperand> operand)
{
- return _operands;
+ const auto id = operand->info().id();
+ const auto &name = operand->name();
+ auto ptr = operand.get();
+
+ CKW_ASSERT(_tensor_id_operands.find(id) == _tensor_id_operands.end());
+ CKW_ASSERT(_operands.find(name) == _operands.end());
+
+ _tensor_id_operands[id] = operand.get();
+ _operands[name] = std::move(operand);
+
+ return *ptr;
}
prototype::GpuKernelWriterDataHolder *Kernel::impl()