diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp | 47 |
1 files changed, 26 insertions, 21 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp index 13c0b141a5..2eafe62bfa 100644 --- a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp +++ b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp @@ -24,6 +24,7 @@ #include "GpuKernelVariableTable.h" #include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/ITensorInfo.h" +#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h" namespace arm_compute { @@ -31,44 +32,48 @@ namespace experimental { namespace dynamic_fusion { -void GpuKernelVariableTable::declare_variable(const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, bool is_interm, const std::string &alias) +void GpuKernelVariableTable::declare_variable(const GpuKernelComponentGroup &comp_group, const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, const std::string &alias) { ARM_COMPUTE_ERROR_ON_MSG(!tensor->has_valid_id(), "Tensor info with valid id expected"); + // Do not re-declare if the variable associated with the tensor has already been declared - if(get_variable(tensor).has_valid_id()) + auto it = _vars.find(tensor->id()); + + if(it != _vars.end()) { - ARM_COMPUTE_ERROR_ON(!(get_variable(tensor).kernel_argument_info == argument_info)); + ARM_COMPUTE_ERROR_ON(!(it->second.kernel_argument_info == argument_info)); return; } - // Declare variable associated with the tensor - std::stringstream ss; - ss << alias << "_t" << tensor->id(); - const auto uniq_name = ss.str(); - TensorVariable var{ tensor->id(), uniq_name, argument_info }; - if(is_interm) + const auto target = comp_group.get_tile_for_tensor(tensor); + + if(target != tensor) { - _interm_var = var; - _interm_tensors.insert(tensor->id()); + // If the tensor uses a shared tile, don't declare another variable. + it = _vars.find(target->id()); + + ARM_COMPUTE_ERROR_ON_MSG( + it == _vars.end(), + "The variable used for this tensor must have been declared."); + + _vars[tensor->id()] = it->second; } else { + // Declare variable associated with the tensor + std::stringstream ss; + ss << alias << "_t" << tensor->id(); + const auto uniq_name = ss.str(); + TensorVariable var{ tensor->id(), uniq_name, argument_info }; + _vars.emplace(tensor->id(), var); } } GpuKernelVariableTable::TensorVariable GpuKernelVariableTable::get_variable(const ITensorInfo *tensor) const { - const TensorVariable empty_var{}; - if(_vars.find(tensor->id()) != _vars.end()) - { - return _vars.at(tensor->id()); - } - if(_interm_tensors.find(tensor->id()) != _interm_tensors.end()) - { - return _interm_var; - } - return empty_var; + const auto var = _vars.at(tensor->id()); + return var; } GpuKernelVariableTable::VariableList GpuKernelVariableTable::get_variable_list(const std::vector<const ITensorInfo *> &tensors) const |