diff options
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h')
-rw-r--r-- | src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h | 66 |
1 files changed, 46 insertions, 20 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h index b285cc2b54..6e1291cdd5 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h @@ -30,6 +30,7 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/GPUTarget.h" #include "src/core/common/Macros.h" +#include "support/StringSupport.h" #include "src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h" @@ -191,7 +192,7 @@ public: struct TagVal { TagVal() = default; - TagVal(SharedVarTable::SharedVar var) + TagVal(const SharedVarTable::SharedVar &var) : value{ var.uniq_name } { } @@ -201,6 +202,11 @@ public: { } + TagVal(const std::string &val) + : value{ val } + { + } + std::string value{}; }; using TagLUT = std::unordered_map<Tag, TagVal>; // Used to instantiating a code template / replacing tags @@ -217,12 +223,12 @@ public: virtual std::vector<Link> get_links() const = 0; virtual std::string name() const = 0; + // @note: some tags can be unused since they could be used only for the macros, or only for the component code static std::string replace_tags(const std::string &code_template, const TagLUT &tags) { - std::string replaced_code = ""; - std::unordered_set<std::string> used_tags{}; - bool scanning_pattern = false; - std::string pattern_found = ""; + std::string replaced_code = ""; + bool scanning_pattern = false; + std::string pattern_found = ""; for(size_t i = 0; i < code_template.size() - 1; ++i) { if(!scanning_pattern) @@ -247,7 +253,6 @@ public: std::string err = "Pattern " + pattern_found + " not found in tags"; ARM_COMPUTE_ERROR_ON_MSG(tags.find(pattern_found) == tags.end(), err.c_str()); replaced_code += tags.find(pattern_found)->second.value; - used_tags.insert(pattern_found); } else { @@ -255,12 +260,7 @@ public: } } } - // Check for unused tags - for(const auto &tag : tags) - { - ARM_COMPUTE_UNUSED(tag); - ARM_COMPUTE_ERROR_ON_MSG(used_tags.find(tag.first) == used_tags.end(), "Warning: unused tags"); - } + return replaced_code; } ComponentID id() const @@ -303,6 +303,11 @@ public: return ""; } + virtual CLBuildOptions generate_build_options() const + { + return CLBuildOptions{}; + } + protected: const ClKernelBlueprint *_blueprint; @@ -445,12 +450,10 @@ public: { std::string name = ""; - auto stack = topological_sort(); - while(!stack.empty()) + traverse([&](std::stack<ComponentID> stack) { name += _components.find(stack.top())->second->name() + (stack.size() > 2 ? "___" : ""); - stack.pop(); - } + }); return name; } @@ -480,7 +483,7 @@ public: headers_list.insert(curr_headers_list.begin(), curr_headers_list.end()); if(!curr_additional_macros.empty()) // Some components might not have any { - additional_macros.insert(curr_additional_macros); + additional_macros.insert(IClKernelComponent::replace_tags(curr_additional_macros, var_lut)); } stack.pop(); @@ -524,7 +527,19 @@ public: CLBuildOptions build_options() const { - return CLBuildOptions{}; + CLBuildOptions build_opts{}; + + traverse([&](std::stack<ComponentID> stack) + { + build_opts.add_options(_components.find(stack.top())->second->generate_build_options().options()); + }); + + return build_opts; + } + + TileDescriptor get_tile_info() const + { + return _tile_info; } Window get_execution_window() const @@ -596,6 +611,17 @@ private: return stack; } + void traverse(const std::function<void(std::stack<ComponentID>)> &func) const + { + std::stack<ComponentID> stack = topological_sort(); + + while(!stack.empty()) + { + func(stack); + stack.pop(); + } + } + std::string generate_argument_declaration(const SharedVarTable::SharedVar &var) const { ARM_COMPUTE_ERROR_ON_MSG(var.group != SharedVarGroup::Argument, "An argument declaration can only be generated from a kernel argument"); @@ -672,7 +698,7 @@ private: ARM_COMPUTE_ERROR("Unsupported clipping strategy"); } - code += "\n REPEAT_VAR_INIT_TO_CONST(M0, uint, g_zout, 0);\n"; + code += "\n REPEAT_VAR_INIT_TO_CONST(" + std::to_string(tile_dim_y) + ", uint, g_zout, 0);\n"; code += " REPEAT_VAR_INIT_TO_CONST(16, uint, g_zero, 0);\n\n"; return code; @@ -684,7 +710,7 @@ private: int32_t _num_components{}; int32_t _num_complex_components{}; - ArgumentID _dst_id{ -1 }; + ArgumentID _dst_id{ -1 }; // Initially set to -1, which means the graph has no dst yet, since node IDs are positive numbers // Argument, components and intermediate tensors IDs with corresponding ptrs (except intermediate) std::unordered_map<ComponentID, ComponentUniquePtr> _components{}; |