diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h | 18 |
1 files changed, 7 insertions, 11 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h index 4eee3963c2..82b7513c0d 100644 --- a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h +++ b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h @@ -39,9 +39,10 @@ namespace experimental { namespace dynamic_fusion { -/** A table of all the variables used in the kernel - * Since fusion is restricted to a linear sequence of components in a kernel, only a single "intermediate variable" (the accumulator) is allowed. - * Each kernel has exactly one variable table +class GpuKernelComponentGroup; + +/** A table of all the variables used in the kernel. + * Each kernel has exactly one variable table. */ class GpuKernelVariableTable { @@ -69,15 +70,12 @@ public: public: /** Declare a @ref TensorVariable for a corresponding tensor info. * - * @note: Later re-declaration of the intermediate variable will overwrite the previous association to the @ref ITensorInfo - * Therefore, the order of declaration is important. It's assumed that the components declaring the variable is already in correct order - * + * @param[in] comp_group Component group the tensor belongs to * @param[in] tensor Tensor info with which the new variable is associated * @param[in] argument_info Kernel argument information - * @param[in] is_interm If the new variable is an intermediate variable * @param[in] alias Alias for the variable. Will be used as part of the variable name */ - void declare_variable(const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, bool is_interm = false, const std::string &alias = "unnamed"); + void declare_variable(const GpuKernelComponentGroup &comp_group, const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, const std::string &alias = "unnamed"); /** Get the @ref TensorVariable associated with @p tensor * * @param[in] tensor Tensor info to be queried @@ -95,9 +93,7 @@ public: VariableList get_variable_list(const std::vector<const ITensorInfo *> &tensors) const; private: - std::map<ITensorInfo::Id, TensorVariable> _vars{}; /**< Non-intermediate (function parameter) variables*/ - TensorVariable _interm_var{}; /**< Intermediate variable */ - std::set<ITensorInfo::Id> _interm_tensors{}; /**< Tensors associated with the single intermediate variable */ + std::map<ITensorInfo::Id, TensorVariable> _vars{}; }; /** A tag value will substitute a tag in a string template during its instantiation */ |