aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h')
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h18
1 files changed, 7 insertions, 11 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
index 4eee3963c2..82b7513c0d 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
@@ -39,9 +39,10 @@ namespace experimental
{
namespace dynamic_fusion
{
-/** A table of all the variables used in the kernel
- * Since fusion is restricted to a linear sequence of components in a kernel, only a single "intermediate variable" (the accumulator) is allowed.
- * Each kernel has exactly one variable table
+class GpuKernelComponentGroup;
+
+/** A table of all the variables used in the kernel.
+ * Each kernel has exactly one variable table.
*/
class GpuKernelVariableTable
{
@@ -69,15 +70,12 @@ public:
public:
/** Declare a @ref TensorVariable for a corresponding tensor info.
*
- * @note: Later re-declaration of the intermediate variable will overwrite the previous association to the @ref ITensorInfo
- * Therefore, the order of declaration is important. It's assumed that the components declaring the variable is already in correct order
- *
+ * @param[in] comp_group Component group the tensor belongs to
* @param[in] tensor Tensor info with which the new variable is associated
* @param[in] argument_info Kernel argument information
- * @param[in] is_interm If the new variable is an intermediate variable
* @param[in] alias Alias for the variable. Will be used as part of the variable name
*/
- void declare_variable(const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, bool is_interm = false, const std::string &alias = "unnamed");
+ void declare_variable(const GpuKernelComponentGroup &comp_group, const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, const std::string &alias = "unnamed");
/** Get the @ref TensorVariable associated with @p tensor
*
* @param[in] tensor Tensor info to be queried
@@ -95,9 +93,7 @@ public:
VariableList get_variable_list(const std::vector<const ITensorInfo *> &tensors) const;
private:
- std::map<ITensorInfo::Id, TensorVariable> _vars{}; /**< Non-intermediate (function parameter) variables*/
- TensorVariable _interm_var{}; /**< Intermediate variable */
- std::set<ITensorInfo::Id> _interm_tensors{}; /**< Tensors associated with the single intermediate variable */
+ std::map<ITensorInfo::Id, TensorVariable> _vars{};
};
/** A tag value will substitute a tag in a string template during its instantiation */