aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/InPlaceOperationMutator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/mutators/InPlaceOperationMutator.cpp')
-rw-r--r--src/graph/mutators/InPlaceOperationMutator.cpp19
1 files changed, 18 insertions, 1 deletions
diff --git a/src/graph/mutators/InPlaceOperationMutator.cpp b/src/graph/mutators/InPlaceOperationMutator.cpp
index 3b06537cd9..327e985625 100644
--- a/src/graph/mutators/InPlaceOperationMutator.cpp
+++ b/src/graph/mutators/InPlaceOperationMutator.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/backends/BackendRegistry.h"
namespace arm_compute
{
@@ -42,13 +43,29 @@ IGraphMutator::MutationType InPlaceOperationMutator::type() const
void InPlaceOperationMutator::mutate(Graph &g)
{
- std::set<NodeType> in_place_nodes = { NodeType::BatchNormalizationLayer, NodeType::ActivationLayer, NodeType::PrintLayer };
+ std::set<NodeType> in_place_nodes =
+ {
+ NodeType::ActivationLayer,
+ NodeType::BatchNormalizationLayer,
+ NodeType::EltwiseLayer,
+ NodeType::PrintLayer,
+ };
// Not interested in the order of nodes
for(auto &node : g.nodes())
{
if(node && in_place_nodes.find(node->type()) != std::end(in_place_nodes))
{
+ // Validate node
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target());
+ Status status = backend.validate_node(*node);
+
+ // If in-place computation is not supported, do nothing and go to next node
+ if(!bool(status))
+ {
+ continue;
+ }
+
// Get input edge
Edge *input_edge = node->input_edge(0);