aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h')
-rw-r--r-- src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h | 66
1 files changed, 46 insertions, 20 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
index b285cc2b54..6e1291cdd5 100644
--- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
+++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
@@ -30,6 +30,7 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/GPUTarget.h"
#include "src/core/common/Macros.h"
+#include "support/StringSupport.h"
#include "src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h"
@@ -191,7 +192,7 @@ public:
struct TagVal
{
TagVal() = default;
- TagVal(SharedVarTable::SharedVar var)
+ TagVal(const SharedVarTable::SharedVar &var) // Builds a tag value from a shared variable; the argument is now taken by const reference instead of by value
: value{ var.uniq_name } // The stored value is the variable's uniq_name
{
}
@@ -201,6 +202,11 @@ public:
{
}
+ TagVal(const std::string &val) // Builds a tag value directly from a string
+ : value{ val } // The string is copied unchanged into value
+ {
+ }
+
std::string value{};
};
using TagLUT = std::unordered_map<Tag, TagVal>; // Maps each tag to its replacement value; used when instantiating a code template (see replace_tags)
@@ -217,12 +223,12 @@ public:
virtual std::vector<Link> get_links() const = 0;
virtual std::string name() const = 0;
+ // @note: some tags can be unused since they could be used only for the macros, or only for the component code
static std::string replace_tags(const std::string &code_template, const TagLUT &tags)
{
- std::string replaced_code = "";
- std::unordered_set<std::string> used_tags{};
- bool scanning_pattern = false;
- std::string pattern_found = "";
+ std::string replaced_code = "";
+ bool scanning_pattern = false;
+ std::string pattern_found = "";
for(size_t i = 0; i < code_template.size() - 1; ++i)
{
if(!scanning_pattern)
@@ -247,7 +253,6 @@ public:
std::string err = "Pattern " + pattern_found + " not found in tags";
ARM_COMPUTE_ERROR_ON_MSG(tags.find(pattern_found) == tags.end(), err.c_str());
replaced_code += tags.find(pattern_found)->second.value;
- used_tags.insert(pattern_found);
}
else
{
@@ -255,12 +260,7 @@ public:
}
}
}
- // Check for unused tags
- for(const auto &tag : tags)
- {
- ARM_COMPUTE_UNUSED(tag);
- ARM_COMPUTE_ERROR_ON_MSG(used_tags.find(tag.first) == used_tags.end(), "Warning: unused tags");
- }
+
return replaced_code;
}
ComponentID id() const
@@ -303,6 +303,11 @@ public:
return "";
}
+ virtual CLBuildOptions generate_build_options() const // Build options contributed by this component; virtual so that a derived component can supply its own
+ {
+ return CLBuildOptions{}; // Default implementation: a default-constructed CLBuildOptions
+ }
+
protected:
const ClKernelBlueprint *_blueprint;
@@ -445,12 +450,10 @@ public:
{
std::string name = "";
- auto stack = topological_sort();
- while(!stack.empty())
+ traverse([&](std::stack<ComponentID> stack)
{
name += _components.find(stack.top())->second->name() + (stack.size() > 2 ? "___" : "");
- stack.pop();
- }
+ });
return name;
}
@@ -480,7 +483,7 @@ public:
headers_list.insert(curr_headers_list.begin(), curr_headers_list.end());
if(!curr_additional_macros.empty()) // Some components might not have any
{
- additional_macros.insert(curr_additional_macros);
+ additional_macros.insert(IClKernelComponent::replace_tags(curr_additional_macros, var_lut));
}
stack.pop();
@@ -524,7 +527,19 @@ public:
CLBuildOptions build_options() const // Gathers the build options contributed by every component visited by traverse()
{
- return CLBuildOptions{};
+ CLBuildOptions build_opts{};
+
+ traverse([&](std::stack<ComponentID> stack) // The callback is invoked by traverse() with the current state of the component stack
+ {
+ build_opts.add_options(_components.find(stack.top())->second->generate_build_options().options()); // Merge in the options of the component whose ID is on top of the stack
+ });
+
+ return build_opts;
+ }
+
+ TileDescriptor get_tile_info() const // Accessor: returns the stored _tile_info by value
+ {
+ return _tile_info;
}
Window get_execution_window() const
@@ -596,6 +611,17 @@ private:
return stack;
}
+ void traverse(const std::function<void(std::stack<ComponentID>)> &func) const // Calls func once per entry of the stack returned by topological_sort(), from the top downwards
+ {
+ std::stack<ComponentID> stack = topological_sort();
+
+ while(!stack.empty())
+ {
+ func(stack); // func receives the stack by value (a copy), so it can inspect top()/size() without affecting this loop
+ stack.pop(); // Advance to the next component only after func has seen the current top
+ }
+ }
+
std::string generate_argument_declaration(const SharedVarTable::SharedVar &var) const
{
ARM_COMPUTE_ERROR_ON_MSG(var.group != SharedVarGroup::Argument, "An argument declaration can only be generated from a kernel argument");
@@ -672,7 +698,7 @@ private:
ARM_COMPUTE_ERROR("Unsupported clipping strategy");
}
- code += "\n REPEAT_VAR_INIT_TO_CONST(M0, uint, g_zout, 0);\n";
+ code += "\n REPEAT_VAR_INIT_TO_CONST(" + std::to_string(tile_dim_y) + ", uint, g_zout, 0);\n";
code += " REPEAT_VAR_INIT_TO_CONST(16, uint, g_zero, 0);\n\n";
return code;
@@ -684,7 +710,7 @@ private:
int32_t _num_components{};
int32_t _num_complex_components{};
- ArgumentID _dst_id{ -1 };
+ ArgumentID _dst_id{ -1 }; // Initially set to -1, which means the graph has no dst yet, since node IDs are positive numbers
// Argument, components and intermediate tensors IDs with corresponding ptrs (except intermediate)
std::unordered_map<ComponentID, ComponentUniquePtr> _components{};