aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/InPlaceOperationMutator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/mutators/InPlaceOperationMutator.cpp')
-rw-r--r--src/graph/mutators/InPlaceOperationMutator.cpp43
1 files changed, 42 insertions, 1 deletions
diff --git a/src/graph/mutators/InPlaceOperationMutator.cpp b/src/graph/mutators/InPlaceOperationMutator.cpp
index 394dba84ff..7c75149eb6 100644
--- a/src/graph/mutators/InPlaceOperationMutator.cpp
+++ b/src/graph/mutators/InPlaceOperationMutator.cpp
@@ -30,6 +30,47 @@ namespace arm_compute
{
namespace graph
{
+namespace
+{
+// Check if the output edges of the parent node are separate tensors. If not,
+// it means the same output is connected to multiple nodes and computations on
+// these nodes cannot be done in-place.
+bool output_edges_are_separate_tensors(Graph &g, const Edge *input_edge)
+{
+ const auto parent_node = input_edge->producer();
+ const auto input_tensor = input_edge->tensor();
+ const auto input_edge_id = input_edge->id();
+
+ if(parent_node == nullptr)
+ {
+ return false;
+ }
+
+ const auto output_edges = parent_node->output_edges();
+
+ // If the output is connected to only one edge, then computations can
+ // be done in-place.
+ if(output_edges.size() == 1)
+ {
+ return true;
+ }
+
+ return std::all_of(output_edges.begin(),
+ output_edges.end(),
+ [&](const EdgeID & edge_id)
+ {
+ // Skip check on current input edge
+ if(edge_id == input_edge_id)
+ {
+ return true;
+ }
+
+ auto edge = g.edge(edge_id);
+ return edge->tensor() != input_tensor;
+ });
+}
+} // namespace
+
const char *InPlaceOperationMutator::name()
{
return "InPlaceOperationMutator";
@@ -60,7 +101,7 @@ void InPlaceOperationMutator::mutate(Graph &g)
Edge *input_edge = node->input_edge(0);
// Check if parent has a single output if yes then force in place calculation else not
- if((input_edge != nullptr) && (input_edge->producer() != nullptr) && (input_edge->producer()->output_edges().size() == 1))
+ if((input_edge != nullptr) && output_edges_are_separate_tensors(g, input_edge))
{
// Get current and new output tensors
auto current_output_tensor = node->output(0);