aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h')
-rw-r--r--src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h59
1 files changed, 55 insertions, 4 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
index 3b5160a055..b285cc2b54 100644
--- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
+++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h
@@ -29,6 +29,7 @@
#include "arm_compute/core/CL/CLCompileContext.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/GPUTarget.h"
+#include "src/core/common/Macros.h"
#include "src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h"
@@ -63,8 +64,8 @@ enum class SharedVarGroup
Automatic // Automatic variables declared within the kernel body
};
-/** Specifies a shared variable ink for a component.
- * It describes all the information that's availbale when a component is constructed / added:
+/** Specifies a shared variable link for a component.
+ * It describes all the information that's available when a component is constructed / added:
* e.g. its linkage (via ArgumentID and io) and its group
* This is not shared variable on its own, but is used for instantiating a SharedVar when building the code
*/
@@ -204,6 +205,13 @@ public:
};
using TagLUT = std::unordered_map<Tag, TagVal>; // Used to instantiating a code template / replacing tags
public:
+ IClKernelComponent(const ClKernelBlueprint *blueprint)
+ : _blueprint(blueprint)
+ {
+ }
+
+ ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(IClKernelComponent);
+
virtual ~IClKernelComponent() = default;
virtual ComponentType get_component_type() const = 0;
virtual std::vector<Link> get_links() const = 0;
@@ -278,6 +286,11 @@ public:
{
return "";
}
+
+ virtual Window get_window() const
+ {
+ return Window{};
+ }
/** "Allocate" all shared variables used in a component to the @p vtable, and generate a TagLUT used to instantiate the component code
*
* @param vtable
@@ -290,6 +303,9 @@ public:
return "";
}
+protected:
+ const ClKernelBlueprint *_blueprint;
+
private:
ComponentID _id{};
};
@@ -398,6 +414,12 @@ public:
// Additionally, set this component as one that treats this argument as "Output" (append to index 1)
else
{
+ if(component->get_component_type() == ComponentType::Store)
+ {
+ ARM_COMPUTE_ERROR_ON_MSG(_dst_id >= 0, "Trying to add more than one dst argument to the graph");
+ _dst_id = arg_id;
+ }
+
for(const auto &subseq_component : _outgoing_components[arg_id])
{
_component_graph[component_id].push_back(subseq_component);
@@ -430,7 +452,6 @@ public:
stack.pop();
}
- std::cout << name << std::endl;
return name;
}
@@ -508,7 +529,15 @@ public:
Window get_execution_window() const
{
- return Window{};
+ ARM_COMPUTE_ERROR_ON_MSG(_graph_root < 0, "No root found in the component graph");
+ ARM_COMPUTE_ERROR_ON_MSG(_dst_id == -1, "Destination Tensor Id should be ready before calling get_execution_window()");
+
+ return _components.find(_graph_root)->second->get_window();
+ }
+
+ ArgumentID get_dst_id() const
+ {
+ return _dst_id;
}
ClKernelArgList get_arguments() const
@@ -521,6 +550,26 @@ public:
return arg_list;
}
+ const ClTensorDescriptor *get_kernel_argument(const ArgumentID id) const
+ {
+ auto it = _kernel_arguments.find(id);
+ if(it != _kernel_arguments.end())
+ {
+ return &_kernel_arguments.find(id)->second;
+ }
+ return nullptr;
+ }
+
+ ITensorInfo *get_kernel_argument_info(const ArgumentID id) const
+ {
+ const ClTensorDescriptor *arg_desc = get_kernel_argument(id);
+ if(arg_desc != nullptr)
+ {
+ return arg_desc->tensor_info;
+ }
+ return nullptr;
+ }
+
private:
void topological_sort_utility(ComponentID component_id, std::unordered_set<ComponentID> &visited, std::stack<ComponentID> &stack) const
{
@@ -635,6 +684,8 @@ private:
int32_t _num_components{};
int32_t _num_complex_components{};
+ ArgumentID _dst_id{ -1 };
+
// Argument, components and intermediate tensors IDs with corresponding ptrs (except intermediate)
std::unordered_map<ComponentID, ComponentUniquePtr> _components{};
std::unordered_map<ArgumentID, ClTensorDescriptor> _kernel_arguments{};