aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/SplitLayerSubTensorMutator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/mutators/SplitLayerSubTensorMutator.cpp')
-rw-r--r--src/graph/mutators/SplitLayerSubTensorMutator.cpp17
1 files changed, 14 insertions, 3 deletions
diff --git a/src/graph/mutators/SplitLayerSubTensorMutator.cpp b/src/graph/mutators/SplitLayerSubTensorMutator.cpp
index 2a8c029843..5f1c9c3186 100644
--- a/src/graph/mutators/SplitLayerSubTensorMutator.cpp
+++ b/src/graph/mutators/SplitLayerSubTensorMutator.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/algorithms/TopologicalSort.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
#include "arm_compute/graph/nodes/SplitLayerNode.h"
@@ -42,10 +43,20 @@ const char *SplitLayerSubTensorMutator::name()
void SplitLayerSubTensorMutator::mutate(Graph &g)
{
+ // Early exit if no Split layers exist in graph
+ if(g.nodes(NodeType::SplitLayer).empty())
+ {
+ return;
+ }
+
+ // Perform topological sort
+ std::vector<NodeID> topological_sorted_node_ids = dfs(g);
+
// Should be in reverse order of execution
- for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
+ for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
{
- if(node && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
+ INode *node = g.node(node_id);
+ if(node != nullptr && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
{
// Get output tensor
Tensor *input_tensor = node->input(0);
@@ -63,7 +74,7 @@ void SplitLayerSubTensorMutator::mutate(Graph &g)
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
<< node->id() << " and name : " << node->name() << std::endl);
- auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node.get());
+ auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node);
const unsigned int axis = split_node->axis();
const unsigned int num_splits = split_node->num_splits();