aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp')
-rw-r--r--src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp45
1 files changed, 27 insertions, 18 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp b/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp
index d5c03c60c5..d78956f835 100644
--- a/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp
+++ b/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp
@@ -30,6 +30,7 @@
#include "arm_compute/core/Window.h"
#include "src/common/utils/Log.h"
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwVariableTable.h"
+#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/utils/TypeConverter.h"
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwKernelWriter.h"
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwScopedKernelWriter.h"
@@ -42,29 +43,24 @@ namespace experimental
namespace dynamic_fusion
{
GpuCkwDriver::GpuCkwDriver(const GpuKernelComponentGroup &components)
- : _components{ components }
+ : _components{ components }, _kernel{ GpuTargetLanguage::OpenCL }
{
}
std::string GpuCkwDriver::get_name()
{
ARM_COMPUTE_LOG_PARAMS(std::string("[V1] TODO"));
- return "todo_get_name";
+ return "unnamed";
}
std::string GpuCkwDriver::get_code()
{
- ARM_COMPUTE_LOG_PARAMS(std::string("[V1] TODO"));
- ckw::Kernel kernel(get_name().c_str(), GpuTargetLanguage::OpenCL);
- GpuCkwKernelWriter root_writer(kernel);
+ _kernel.name(get_name());
+ GpuCkwKernelWriter root_writer(_kernel);
GpuCkwScopedKernelWriter writer(&root_writer);
GpuCkwVariableTable vtable{};
// Global Kernel Writer Driver code
-
- // The following is just an incomplete example of using the kernel writer
-
- // Iterate over component specific Ckw Driver; generate component code and concatenate them
for(auto &comp : _components)
{
auto ckw_driver = comp->ckw_component_driver();
@@ -96,18 +92,31 @@ Window GpuCkwDriver::get_window() const
return root_comp->ckw_component_driver()->get_window();
}
-std::map<ITensorInfo::Id, GpuKernelArgument> GpuCkwDriver::get_tensors()
+GpuKernelArgumentList GpuCkwDriver::get_kernel_arguments()
{
- ARM_COMPUTE_LOG_PARAMS(std::string("[V1] TODO"));
- // Assemble GpuKernelArguments
- std::map<ITensorInfo::Id, GpuKernelArgument> tensors;
- for(const auto t : _components.get_argument_tensors())
+ GpuKernelArgumentList args{};
+ for(const auto &arg : _kernel.arguments())
{
- tensors.emplace(
- t->id(),
- GpuKernelArgument{ *t, { GpuKernelArgumentInfo::Type::Tensor_Special_0 } });
+ switch(arg.type())
+ {
+ case KernelArgument::Type::TensorStorage:
+ {
+ args.emplace_back(static_cast<ITensorInfo::Id>(arg.id()), from_ckw(arg.tensor_storage_type()));
+ break;
+ }
+ case KernelArgument::Type::TensorComponent:
+ {
+ args.emplace_back(static_cast<ITensorInfo::Id>(arg.id()), from_ckw(arg.tensor_component_type()));
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported KernelArgument Type");
+ break;
+ }
+ }
}
- return tensors;
+ return args;
}
} // namespace dynamic_fusion