aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp')
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp47
1 files changed, 26 insertions, 21 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
index 13c0b141a5..2eafe62bfa 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
@@ -24,6 +24,7 @@
#include "GpuKernelVariableTable.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/ITensorInfo.h"
+#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
namespace arm_compute
{
@@ -31,44 +32,48 @@ namespace experimental
{
namespace dynamic_fusion
{
-void GpuKernelVariableTable::declare_variable(const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, bool is_interm, const std::string &alias)
+void GpuKernelVariableTable::declare_variable(const GpuKernelComponentGroup &comp_group, const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, const std::string &alias)
{
ARM_COMPUTE_ERROR_ON_MSG(!tensor->has_valid_id(), "Tensor info with valid id expected");
+
// Do not re-declare if the variable associated with the tensor has already been declared
- if(get_variable(tensor).has_valid_id())
+ auto it = _vars.find(tensor->id());
+
+ if(it != _vars.end())
{
- ARM_COMPUTE_ERROR_ON(!(get_variable(tensor).kernel_argument_info == argument_info));
+ ARM_COMPUTE_ERROR_ON(!(it->second.kernel_argument_info == argument_info));
return;
}
- // Declare variable associated with the tensor
- std::stringstream ss;
- ss << alias << "_t" << tensor->id();
- const auto uniq_name = ss.str();
- TensorVariable var{ tensor->id(), uniq_name, argument_info };
- if(is_interm)
+ const auto target = comp_group.get_tile_for_tensor(tensor);
+
+ if(target != tensor)
{
- _interm_var = var;
- _interm_tensors.insert(tensor->id());
+ // If the tensor uses a shared tile, don't declare another variable.
+ it = _vars.find(target->id());
+
+ ARM_COMPUTE_ERROR_ON_MSG(
+ it == _vars.end(),
+ "The variable used for this tensor must have been declared.");
+
+ _vars[tensor->id()] = it->second;
}
else
{
+ // Declare variable associated with the tensor
+ std::stringstream ss;
+ ss << alias << "_t" << tensor->id();
+ const auto uniq_name = ss.str();
+ TensorVariable var{ tensor->id(), uniq_name, argument_info };
+
_vars.emplace(tensor->id(), var);
}
}
GpuKernelVariableTable::TensorVariable GpuKernelVariableTable::get_variable(const ITensorInfo *tensor) const
{
- const TensorVariable empty_var{};
- if(_vars.find(tensor->id()) != _vars.end())
- {
- return _vars.at(tensor->id());
- }
- if(_interm_tensors.find(tensor->id()) != _interm_tensors.end())
- {
- return _interm_var;
- }
- return empty_var;
+ const auto var = _vars.at(tensor->id());
+ return var;
}
GpuKernelVariableTable::VariableList GpuKernelVariableTable::get_variable_list(const std::vector<const ITensorInfo *> &tensors) const