aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/DepthConcatSubTensorMutator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/mutators/DepthConcatSubTensorMutator.cpp')
-rw-r--r--src/graph/mutators/DepthConcatSubTensorMutator.cpp19
1 files changed, 15 insertions, 4 deletions
diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
index 241c07b367..937528d143 100644
--- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp
+++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
#include "arm_compute/graph/Utils.h"
+#include "arm_compute/graph/algorithms/TopologicalSort.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
@@ -43,16 +44,26 @@ const char *DepthConcatSubTensorMutator::name()
void DepthConcatSubTensorMutator::mutate(Graph &g)
{
+ // Early exit if no Concatenation layers exist in graph
+ if(g.nodes(NodeType::ConcatenateLayer).empty())
+ {
+ return;
+ }
+
+ // Perform topological sort
+ std::vector<NodeID> topological_sorted_node_ids = dfs(g);
+
// Should be in reverse order of execution
- for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
+ for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
{
- if(node && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr)
+ INode *node = g.node(node_id);
+ if(node != nullptr && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr)
{
// Get output tensor
auto output_tensor = node->output(0);
// Check concatenation axis (Sub-tensor optimization is support for concatenation axis >=2)
- auto *concat_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node.get());
+ auto *concat_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node);
if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc(), concat_node->concatenation_axis()) < 2)
{
continue;
@@ -84,7 +95,7 @@ void DepthConcatSubTensorMutator::mutate(Graph &g)
depth += input_shape.z();
}
- auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node.get());
+ auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node);
dc_node->set_enabled(false);
}
}