From 2a2db590fd179dcb8e1a575293cd2b887e2dc246 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 15 Aug 2018 12:14:46 +0100 Subject: COMPMID-1505: Add native grouping support at graph level Change-Id: Iedc91b0aee743b59af5140c8acb8124548da3163 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144362 Tested-by: Jenkins Reviewed-by: Giorgio Arena Reviewed-by: Michele DiGiorgio --- src/graph/mutators/SplitLayerSubTensorMutator.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) (limited to 'src/graph/mutators/SplitLayerSubTensorMutator.cpp') diff --git a/src/graph/mutators/SplitLayerSubTensorMutator.cpp b/src/graph/mutators/SplitLayerSubTensorMutator.cpp index 2a8c029843..5f1c9c3186 100644 --- a/src/graph/mutators/SplitLayerSubTensorMutator.cpp +++ b/src/graph/mutators/SplitLayerSubTensorMutator.cpp @@ -25,6 +25,7 @@ #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/Logger.h" +#include "arm_compute/graph/algorithms/TopologicalSort.h" #include "arm_compute/graph/backends/BackendRegistry.h" #include "arm_compute/graph/nodes/SplitLayerNode.h" @@ -42,10 +43,20 @@ const char *SplitLayerSubTensorMutator::name() void SplitLayerSubTensorMutator::mutate(Graph &g) { + // Early exit if no Split layers exist in graph + if(g.nodes(NodeType::SplitLayer).empty()) + { + return; + } + + // Perform topological sort + std::vector topological_sorted_node_ids = dfs(g); + // Should be in reverse order of execution - for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes())) + for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids)) { - if(node && node->type() == NodeType::SplitLayer && node->input(0) != nullptr) + INode *node = g.node(node_id); + if(node != nullptr && node->type() == NodeType::SplitLayer && node->input(0) != nullptr) { // Get output tensor Tensor *input_tensor = node->input(0); @@ -63,7 +74,7 @@ void SplitLayerSubTensorMutator::mutate(Graph &g) ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : " << node->id() << " and name : " << node->name() << std::endl); - auto *split_node = arm_compute::utils::cast::polymorphic_downcast(node.get()); + auto *split_node = arm_compute::utils::cast::polymorphic_downcast(node); const unsigned int axis = split_node->axis(); const unsigned int num_splits = split_node->num_splits(); -- cgit v1.2.1