diff options
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h')
-rw-r--r-- | src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h | 59 |
1 files changed, 55 insertions, 4 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h index 3b5160a055..b285cc2b54 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h @@ -29,6 +29,7 @@ #include "arm_compute/core/CL/CLCompileContext.h" #include "arm_compute/core/Error.h" #include "arm_compute/core/GPUTarget.h" +#include "src/core/common/Macros.h" #include "src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h" @@ -63,8 +64,8 @@ enum class SharedVarGroup Automatic // Automatic variables declared within the kernel body }; -/** Specifies a shared variable ink for a component. - * It describes all the information that's availbale when a component is constructed / added: +/** Specifies a shared variable link for a component. + * It describes all the information that's available when a component is constructed / added: * e.g. its linkage (via ArgumentID and io) and its group * This is not shared variable on its own, but is used for instantiating a SharedVar when building the code */ @@ -204,6 +205,13 @@ public: }; using TagLUT = std::unordered_map<Tag, TagVal>; // Used to instantiating a code template / replacing tags public: + IClKernelComponent(const ClKernelBlueprint *blueprint) + : _blueprint(blueprint) + { + } + + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(IClKernelComponent); + virtual ~IClKernelComponent() = default; virtual ComponentType get_component_type() const = 0; virtual std::vector<Link> get_links() const = 0; @@ -278,6 +286,11 @@ public: { return ""; } + + virtual Window get_window() const + { + return Window{}; + } /** "Allocate" all shared variables used in a component to the @p vtable, and generate a TagLUT used to instantiate the component code * * @param vtable @@ -290,6 +303,9 @@ public: return ""; } +protected: + const ClKernelBlueprint *_blueprint; + private: ComponentID _id{}; }; @@ -398,6 +414,12 @@ public: // Additionally, set this component as one that treats this argument as "Output" (append to index 1) else { + if(component->get_component_type() == ComponentType::Store) + { + ARM_COMPUTE_ERROR_ON_MSG(_dst_id >= 0, "Trying to add more than one dst argument to the graph"); + _dst_id = arg_id; + } + for(const auto &subseq_component : _outgoing_components[arg_id]) { _component_graph[component_id].push_back(subseq_component); @@ -430,7 +452,6 @@ public: stack.pop(); } - std::cout << name << std::endl; return name; } @@ -508,7 +529,15 @@ public: Window get_execution_window() const { - return Window{}; + ARM_COMPUTE_ERROR_ON_MSG(_graph_root < 0, "No root found in the component graph"); + ARM_COMPUTE_ERROR_ON_MSG(_dst_id == -1, "Destination Tensor Id should be ready before calling get_execution_window()"); + + return _components.find(_graph_root)->second->get_window(); + } + + ArgumentID get_dst_id() const + { + return _dst_id; } ClKernelArgList get_arguments() const @@ -521,6 +550,26 @@ public: return arg_list; } + const ClTensorDescriptor *get_kernel_argument(const ArgumentID id) const + { + auto it = _kernel_arguments.find(id); + if(it != _kernel_arguments.end()) + { + return &_kernel_arguments.find(id)->second; + } + return nullptr; + } + + ITensorInfo *get_kernel_argument_info(const ArgumentID id) const + { + const ClTensorDescriptor *arg_desc = get_kernel_argument(id); + if(arg_desc != nullptr) + { + return arg_desc->tensor_info; + } + return nullptr; + } + private: void topological_sort_utility(ComponentID component_id, std::unordered_set<ComponentID> &visited, std::stack<ComponentID> &stack) const { @@ -635,6 +684,8 @@ private: int32_t _num_components{}; int32_t _num_complex_components{}; + ArgumentID _dst_id{ -1 }; + // Argument, components and intermediate tensors IDs with corresponding ptrs (except intermediate) std::unordered_map<ComponentID, ComponentUniquePtr> _components{}; std::unordered_map<ArgumentID, ClTensorDescriptor> _kernel_arguments{}; |