diff options
Diffstat (limited to 'src/graph/mutators/InPlaceOperationMutator.cpp')
-rw-r--r-- | src/graph/mutators/InPlaceOperationMutator.cpp | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/src/graph/mutators/InPlaceOperationMutator.cpp b/src/graph/mutators/InPlaceOperationMutator.cpp index 3b06537cd9..327e985625 100644 --- a/src/graph/mutators/InPlaceOperationMutator.cpp +++ b/src/graph/mutators/InPlaceOperationMutator.cpp @@ -25,6 +25,7 @@ #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/Logger.h" +#include "arm_compute/graph/backends/BackendRegistry.h" namespace arm_compute { @@ -42,13 +43,29 @@ IGraphMutator::MutationType InPlaceOperationMutator::type() const void InPlaceOperationMutator::mutate(Graph &g) { - std::set<NodeType> in_place_nodes = { NodeType::BatchNormalizationLayer, NodeType::ActivationLayer, NodeType::PrintLayer }; + std::set<NodeType> in_place_nodes = + { + NodeType::ActivationLayer, + NodeType::BatchNormalizationLayer, + NodeType::EltwiseLayer, + NodeType::PrintLayer, + }; // Not interested in the order of nodes for(auto &node : g.nodes()) { if(node && in_place_nodes.find(node->type()) != std::end(in_place_nodes)) { + // Validate node + backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target()); + Status status = backend.validate_node(*node); + + // If in-place computation is not supported, do nothing and go to next node + if(!bool(status)) + { + continue; + } + // Get input edge Edge *input_edge = node->input_edge(0); |