diff options
Diffstat (limited to 'compute_kernel_writer/prototype/src/Kernel.cpp')
-rw-r--r-- | compute_kernel_writer/prototype/src/Kernel.cpp | 99 |
1 files changed, 94 insertions, 5 deletions
diff --git a/compute_kernel_writer/prototype/src/Kernel.cpp b/compute_kernel_writer/prototype/src/Kernel.cpp
index 692d504887..884b69afc6 100644
--- a/compute_kernel_writer/prototype/src/Kernel.cpp
+++ b/compute_kernel_writer/prototype/src/Kernel.cpp
@@ -23,6 +23,7 @@
  */
 
 #include "ckw/Kernel.h"
+#include "ckw/TensorOperand.h"
 #include "ckw/types/GpuTargetLanguage.h"
 #include "src/Prototype.h"
 
@@ -30,7 +31,7 @@ namespace ckw
 {
 
 Kernel::Kernel(const char *name, GpuTargetLanguage language)
-    : _name(name), _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)), _operands{}
+    : _name(name), _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)), _operands{}, _tensor_id_operands{}
 {
 }
 
@@ -43,14 +44,102 @@ const std::string &Kernel::name() const
     return _name;
 }
 
-const std::map<std::string, std::unique_ptr<OperandBase>> &Kernel::operands() const
+std::vector<KernelArgument> Kernel::arguments() const
 {
-    return _operands;
+    std::vector<KernelArgument> arguments;
+
+    const auto impl_args = _kernel->arguments.tensor_argument_declarations();
+
+    for(auto tensor_arg : impl_args)
+    {
+        auto tensor = _tensor_id_operands.at(tensor_arg->format().id);
+        arguments.push_back(*tensor);
+
+        for(auto component_arg : tensor_arg->component_declarations())
+        {
+            switch(component_arg)
+            {
+                case TensorComponentType::OffsetFirstElement:
+                    arguments.push_back(tensor->offset_first_element_in_bytes());
+                    break;
+
+                case TensorComponentType::Stride1:
+                    arguments.push_back(tensor->stride1());
+                    break;
+
+                case TensorComponentType::Stride2:
+                    arguments.push_back(tensor->stride2());
+                    break;
+
+                case TensorComponentType::Stride3:
+                    arguments.push_back(tensor->stride3());
+                    break;
+
+                case TensorComponentType::Stride4:
+                    arguments.push_back(tensor->stride4());
+                    break;
+
+                case TensorComponentType::Dim0:
+                    arguments.push_back(tensor->dim0());
+                    break;
+
+                case TensorComponentType::Dim1:
+                    arguments.push_back(tensor->dim1());
+                    break;
+
+                case TensorComponentType::Dim2:
+                    arguments.push_back(tensor->dim2());
+                    break;
+
+                case TensorComponentType::Dim3:
+                    arguments.push_back(tensor->dim3());
+                    break;
+
+                case TensorComponentType::Dim4:
+                    arguments.push_back(tensor->dim4());
+                    break;
+
+                case TensorComponentType::Dim1xDim2:
+                    arguments.push_back(tensor->dim1_dim2());
+                    break;
+
+                case TensorComponentType::Dim1xDim2xDim3:
+                    arguments.push_back(tensor->dim1_dim2_dim3());
+                    break;
+
+                default:
+                    CKW_ASSERT(false);
+            }
+        }
+    }
+
+    return arguments;
+}
+
+TileOperand &Kernel::register_operand(std::unique_ptr<TileOperand> operand)
+{
+    const auto &name = operand->name();
+    auto ptr = operand.get();
+
+    CKW_ASSERT(_operands.find(name) == _operands.end());
+    _operands[name] = std::move(operand);
+
+    return *ptr;
 }
 
-std::map<std::string, std::unique_ptr<OperandBase>> &Kernel::operands()
+TensorOperand &Kernel::register_operand(std::unique_ptr<TensorOperand> operand)
 {
-    return _operands;
+    const auto id = operand->info().id();
+    const auto &name = operand->name();
+    auto ptr = operand.get();
+
+    CKW_ASSERT(_tensor_id_operands.find(id) == _tensor_id_operands.end());
+    CKW_ASSERT(_operands.find(name) == _operands.end());
+
+    _tensor_id_operands[id] = operand.get();
+    _operands[name] = std::move(operand);
+
+    return *ptr;
 }
 
 prototype::GpuKernelWriterDataHolder *Kernel::impl()