diff options
Diffstat (limited to 'src/graph/mutators/DepthConcatSubTensorMutator.cpp')
-rw-r--r-- | src/graph/mutators/DepthConcatSubTensorMutator.cpp | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp index 241c07b367..937528d143 100644 --- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp +++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp @@ -26,6 +26,7 @@ #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/Logger.h" #include "arm_compute/graph/Utils.h" +#include "arm_compute/graph/algorithms/TopologicalSort.h" #include "arm_compute/graph/backends/BackendRegistry.h" #include "arm_compute/graph/nodes/ConcatenateLayerNode.h" @@ -43,16 +44,26 @@ const char *DepthConcatSubTensorMutator::name() void DepthConcatSubTensorMutator::mutate(Graph &g) { + // Early exit if no Concatenation layers exist in graph + if(g.nodes(NodeType::ConcatenateLayer).empty()) + { + return; + } + + // Perform topological sort + std::vector<NodeID> topological_sorted_node_ids = dfs(g); + // Should be in reverse order of execution - for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes())) + for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids)) { - if(node && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr) + INode *node = g.node(node_id); + if(node != nullptr && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr) { // Get output tensor auto output_tensor = node->output(0); // Check concatenation axis (Sub-tensor optimization is support for concatenation axis >=2) - auto *concat_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node.get()); + auto *concat_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node); if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc(), concat_node->concatenation_axis()) < 2) { continue; @@ -84,7 +95,7 @@ void DepthConcatSubTensorMutator::mutate(Graph &g) depth += input_shape.z(); } - auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node.get()); + auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node); dc_node->set_enabled(false); } } |