aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h')
-rw-r--r--src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h46
1 files changed, 37 insertions, 9 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
index 4c720ea1aa..e24c742fd7 100644
--- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
+++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
@@ -31,6 +31,7 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/GPUTarget.h"
#include "src/core/common/Macros.h"
+#include "support/Requires.h"
#include "support/StringSupport.h"
#include "src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h"
@@ -198,8 +199,9 @@ public:
{
}
- TagVal(ComponentID id)
- : value{ std::to_string(id) }
+ template <typename T, ARM_COMPUTE_REQUIRES_TA(std::is_integral<T>::value)>
+ TagVal(T val)
+ : value{ support::cpp11::to_string(val) }
{
}
@@ -208,6 +210,16 @@ public:
{
}
+ TagVal(const char *val)
+ : value{ std::string(val) }
+ {
+ }
+
+ TagVal(const DataType &data_type)
+ : value{ get_cl_type_from_data_type(data_type) }
+ {
+ }
+
std::string value{};
};
using TagLUT = std::unordered_map<Tag, TagVal>; // Used to instantiating a code template / replacing tags
@@ -633,21 +645,36 @@ private:
std::string code;
switch(var.desc.tensor_arg_type)
{
+ case TensorArgType::Vector:
+ {
+ code += "\n VECTOR_DECLARATION(" + var.uniq_name + ")";
+ break;
+ }
case TensorArgType::Image:
{
- code += "IMAGE_DECLARATION(" + var.uniq_name + ")";
+ code += "\n IMAGE_DECLARATION(" + var.uniq_name + ")";
break;
}
case TensorArgType::Image_3D:
{
- code += "IMAGE_DECLARATION(" + var.uniq_name + "),\n";
- code += "uint " + var.uniq_name + "_stride_z";
+ code += "\n IMAGE_DECLARATION(" + var.uniq_name + "),";
+ code += "\n uint " + var.uniq_name + "_stride_z";
break;
}
case TensorArgType::Image_3D_Export_To_ClImage2D:
{
- code += "__read_only image2d_t " + var.uniq_name + "_img,\n";
- code += "uint " + var.uniq_name + "_stride_z,\n";
+ code += "\n __read_only image2d_t " + var.uniq_name + "_img,";
+ code += "\n uint " + var.uniq_name + "_stride_z";
+ break;
+ }
+ case TensorArgType::Tensor_4D_t_Buffer:
+ {
+ code += "\n TENSOR4D_T(" + var.uniq_name + ", BUFFER)";
+ break;
+ }
+ case TensorArgType::Tensor_4D_t_Image:
+ {
+ code += "\n TENSOR4D_T(" + var.uniq_name + ", IMAGE)";
break;
}
default:
@@ -664,7 +691,7 @@ private:
for(const auto &arg : argument_list)
{
- code += "\n " + generate_argument_declaration(arg) + ",";
+ code += generate_argument_declaration(arg) + ",";
}
code[code.length() - 1] = ')';
@@ -674,7 +701,8 @@ private:
std::string generate_global_section() const
{
- std::string code = " uint g_x = get_global_id(0);\n";
+ std::string code = "";
+ code += " uint g_x = get_global_id(0);\n";
code += " uint g_y = get_global_id(1);\n";
code += " uint g_z = get_global_id(2);\n\n";