diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-06-04 15:05:38 +0100 |
---|---|---|
committer | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-06-15 13:59:04 +0000 |
commit | 4a61653202afb018f4f259d3c144a735d73f0a20 (patch) | |
tree | 082fd42e91cc0914dcacc0746bbe3e117d74210c /src/graph | |
parent | ccd94966cc58ef5148577e71ba1a4ff5aae1f3bb (diff) | |
download | ComputeLibrary-4a61653202afb018f4f259d3c144a735d73f0a20.tar.gz |
COMPMID-3480: Perform in-place computations in NEArithmeticAdditionKernel
Change-Id: I0089657dd95d7c7b8592984def8e8de1d7e6d085
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3308
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/mutators/InPlaceOperationMutator.cpp | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/src/graph/mutators/InPlaceOperationMutator.cpp b/src/graph/mutators/InPlaceOperationMutator.cpp index 3b06537cd9..327e985625 100644 --- a/src/graph/mutators/InPlaceOperationMutator.cpp +++ b/src/graph/mutators/InPlaceOperationMutator.cpp @@ -25,6 +25,7 @@ #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/Logger.h" +#include "arm_compute/graph/backends/BackendRegistry.h" namespace arm_compute { @@ -42,13 +43,29 @@ IGraphMutator::MutationType InPlaceOperationMutator::type() const void InPlaceOperationMutator::mutate(Graph &g) { - std::set<NodeType> in_place_nodes = { NodeType::BatchNormalizationLayer, NodeType::ActivationLayer, NodeType::PrintLayer }; + std::set<NodeType> in_place_nodes = + { + NodeType::ActivationLayer, + NodeType::BatchNormalizationLayer, + NodeType::EltwiseLayer, + NodeType::PrintLayer, + }; // Not interested in the order of nodes for(auto &node : g.nodes()) { if(node && in_place_nodes.find(node->type()) != std::end(in_place_nodes)) { + // Validate node + backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target()); + Status status = backend.validate_node(*node); + + // If in-place computation is not supported, do nothing and go to next node + if(!bool(status)) + { + continue; + } + // Get input edge Edge *input_edge = node->input_edge(0); |