about | summary | refs | log | tree | commit | diff
path: root/src/graph
diff options
context:
space:
mode:
author Michele Di Giorgio <michele.digiorgio@arm.com> 2018-02-13 15:24:04 +0000
committer Anthony Barbier <anthony.barbier@arm.com> 2018-11-02 16:47:18 +0000
commit dde9ec96f471127e5b6d8dfaeffce024b6326f1a (patch)
tree 3aa88c0dec625feeb9d17da825b87398cac6cc68 /src/graph
parent e3fba0afa892c66379da1e3d3843f2155a1fb29a (diff)
download ComputeLibrary-dde9ec96f471127e5b6d8dfaeffce024b6326f1a.tar.gz
COMPMID-909: Enabling in-place computation for batchnormalization and activation at graph level
Change-Id: I84d4a212629b21794451ab5fb5c5b187b5e28f98
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/120127
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph')
-rw-r--r-- src/graph/Graph.cpp 13
-rw-r--r-- src/graph/INode.cpp 10
-rw-r--r-- src/graph/SubGraph.cpp 6
-rw-r--r-- src/graph/nodes/ActivationLayer.cpp 3
4 files changed, 27 insertions, 5 deletions
diff --git a/src/graph/Graph.cpp b/src/graph/Graph.cpp
index 7af313acbb..98d95904dc 100644
--- a/src/graph/Graph.cpp
+++ b/src/graph/Graph.cpp
@@ -131,6 +131,11 @@ void Graph::Private::configure(GraphHints _next_hints)
_previous_hints = _current_hints; // For the first node just assume the previous node was of the same type as this one
}
+ if(_current_node->supports_in_place())
+ {
+ _current_output = _current_input;
+ }
+
//Automatic output configuration ?
if(_current_output == nullptr)
{
@@ -152,8 +157,12 @@ void Graph::Private::configure(GraphHints _next_hints)
_ctx.hints() = _current_hints;
std::unique_ptr<arm_compute::IFunction> func = _current_node->instantiate_node(_ctx, _current_input, _current_output);
- // Allocate current input
- _current_input->allocate();
+ // If the operation is done in-place, do not allocate or it will prevent following layers from performing the configuration
+ if(!_current_node->supports_in_place())
+ {
+ // Allocate current input
+ _current_input->allocate();
+ }
// Map input if needed
if(_current_input->target() == TargetHint::OPENCL)
diff --git a/src/graph/INode.cpp b/src/graph/INode.cpp
index 582f936351..c753f66b43 100644
--- a/src/graph/INode.cpp
+++ b/src/graph/INode.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -39,6 +39,14 @@ TargetHint INode::override_target_hint(TargetHint target_hint) const
ARM_COMPUTE_ERROR_ON(target_hint == TargetHint::OPENCL && !opencl_is_available());
return target_hint;
}
+bool INode::supports_in_place() const
+{
+ return _supports_in_place;
+}
+void INode::set_supports_in_place(bool value)
+{
+ _supports_in_place = value;
+}
GraphHints INode::node_override_hints(GraphHints hints) const
{
TargetHint target_hint = hints.target_hint();
diff --git a/src/graph/SubGraph.cpp b/src/graph/SubGraph.cpp
index f62b2617c5..b1cbb9cc95 100644
--- a/src/graph/SubGraph.cpp
+++ b/src/graph/SubGraph.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -67,6 +67,10 @@ std::unique_ptr<Graph> SubGraph::construct(const GraphContext &ctx, std::unique_
}
graph->add_tensor_object(std::move(_input));
+ // Make sure first and last nodes of the subgraph always do operations out-of-place
+ _nodes.front()->set_supports_in_place(false);
+ _nodes.back()->set_supports_in_place(false);
+
// Construct nodes
for(auto &node : _nodes)
{
diff --git a/src/graph/nodes/ActivationLayer.cpp b/src/graph/nodes/ActivationLayer.cpp
index 54f30ef777..546c42a1e5 100644
--- a/src/graph/nodes/ActivationLayer.cpp
+++ b/src/graph/nodes/ActivationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,6 +33,7 @@ using namespace arm_compute::graph;
ActivationLayer::ActivationLayer(const ActivationLayerInfo activation_info)
: _activation_info(activation_info)
{
+ set_supports_in_place(true);
}
std::unique_ptr<arm_compute::IFunction> ActivationLayer::instantiate_node(GraphContext &ctx, ITensorObject *input, ITensorObject *output)