aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/InPlaceOperationMutator.cpp
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-06-04 15:05:38 +0100
committerMichele Di Giorgio <michele.digiorgio@arm.com>2020-06-15 13:59:04 +0000
commit4a61653202afb018f4f259d3c144a735d73f0a20 (patch)
tree082fd42e91cc0914dcacc0746bbe3e117d74210c /src/graph/mutators/InPlaceOperationMutator.cpp
parentccd94966cc58ef5148577e71ba1a4ff5aae1f3bb (diff)
downloadComputeLibrary-4a61653202afb018f4f259d3c144a735d73f0a20.tar.gz
COMPMID-3480: Perform in-place computations in NEArithmeticAdditionKernel
Change-Id: I0089657dd95d7c7b8592984def8e8de1d7e6d085 Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3308 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph/mutators/InPlaceOperationMutator.cpp')
-rw-r--r--src/graph/mutators/InPlaceOperationMutator.cpp19
1 files changed, 18 insertions, 1 deletions
diff --git a/src/graph/mutators/InPlaceOperationMutator.cpp b/src/graph/mutators/InPlaceOperationMutator.cpp
index 3b06537cd9..327e985625 100644
--- a/src/graph/mutators/InPlaceOperationMutator.cpp
+++ b/src/graph/mutators/InPlaceOperationMutator.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/backends/BackendRegistry.h"
namespace arm_compute
{
@@ -42,13 +43,29 @@ IGraphMutator::MutationType InPlaceOperationMutator::type() const
void InPlaceOperationMutator::mutate(Graph &g)
{
- std::set<NodeType> in_place_nodes = { NodeType::BatchNormalizationLayer, NodeType::ActivationLayer, NodeType::PrintLayer };
+ std::set<NodeType> in_place_nodes =
+ {
+ NodeType::ActivationLayer,
+ NodeType::BatchNormalizationLayer,
+ NodeType::EltwiseLayer,
+ NodeType::PrintLayer,
+ };
// Not interested in the order of nodes
for(auto &node : g.nodes())
{
if(node && in_place_nodes.find(node->type()) != std::end(in_place_nodes))
{
+ // Validate node
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target());
+ Status status = backend.validate_node(*node);
+
+ // If in-place computation is not supported, do nothing and go to next node
+ if(!bool(status))
+ {
+ continue;
+ }
+
// Get input edge
Edge *input_edge = node->input_edge(0);