diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp | 45 |
1 files changed, 27 insertions, 18 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp b/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp index d5c03c60c5..d78956f835 100644 --- a/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp +++ b/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp @@ -30,6 +30,7 @@ #include "arm_compute/core/Window.h" #include "src/common/utils/Log.h" #include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwVariableTable.h" +#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/utils/TypeConverter.h" #include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwKernelWriter.h" #include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwScopedKernelWriter.h" @@ -42,29 +43,24 @@ namespace experimental namespace dynamic_fusion { GpuCkwDriver::GpuCkwDriver(const GpuKernelComponentGroup &components) - : _components{ components } + : _components{ components }, _kernel{ GpuTargetLanguage::OpenCL } { } std::string GpuCkwDriver::get_name() { ARM_COMPUTE_LOG_PARAMS(std::string("[V1] TODO")); - return "todo_get_name"; + return "unnamed"; } std::string GpuCkwDriver::get_code() { - ARM_COMPUTE_LOG_PARAMS(std::string("[V1] TODO")); - ckw::Kernel kernel(get_name().c_str(), GpuTargetLanguage::OpenCL); - GpuCkwKernelWriter root_writer(kernel); + _kernel.name(get_name()); + GpuCkwKernelWriter root_writer(_kernel); GpuCkwScopedKernelWriter writer(&root_writer); GpuCkwVariableTable vtable{}; // Global Kernel Writer Driver code - - // The following is just an incomplete example of using the kernel writer - - // Iterate over component specific Ckw Driver; generate component code and concatenate them for(auto &comp : _components) { auto ckw_driver = comp->ckw_component_driver(); @@ -96,18 +92,31 @@ Window GpuCkwDriver::get_window() const return root_comp->ckw_component_driver()->get_window(); } -std::map<ITensorInfo::Id, GpuKernelArgument> GpuCkwDriver::get_tensors() +GpuKernelArgumentList GpuCkwDriver::get_kernel_arguments() { - ARM_COMPUTE_LOG_PARAMS(std::string("[V1] TODO")); - // Assemble GpuKernelArguments - std::map<ITensorInfo::Id, GpuKernelArgument> tensors; - for(const auto t : _components.get_argument_tensors()) + GpuKernelArgumentList args{}; + for(const auto &arg : _kernel.arguments()) { - tensors.emplace( - t->id(), - GpuKernelArgument{ *t, { GpuKernelArgumentInfo::Type::Tensor_Special_0 } }); + switch(arg.type()) + { + case KernelArgument::Type::TensorStorage: + { + args.emplace_back(static_cast<ITensorInfo::Id>(arg.id()), from_ckw(arg.tensor_storage_type())); + break; + } + case KernelArgument::Type::TensorComponent: + { + args.emplace_back(static_cast<ITensorInfo::Id>(arg.id()), from_ckw(arg.tensor_component_type())); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported KernelArgument Type"); + break; + } + } } - return tensors; + return args; } } // namespace dynamic_fusion |