diff options
Diffstat (limited to 'src/graph/Graph.cpp')
-rw-r--r-- | src/graph/Graph.cpp | 87 |
1 files changed, 49 insertions, 38 deletions
diff --git a/src/graph/Graph.cpp b/src/graph/Graph.cpp index 525506f316..25c4577df7 100644 --- a/src/graph/Graph.cpp +++ b/src/graph/Graph.cpp @@ -46,7 +46,7 @@ public: * * @param _next_hint Device execution hint */ - void configure(Hint _next_hint); + void configure(GraphHints _next_hints); /** Sets whether to enable information print out * @@ -54,11 +54,12 @@ public: */ void set_info_enablement(bool is_enabled); + GraphContext _ctx{}; std::vector<Stage> _pipeline{}; std::vector<std::unique_ptr<Tensor>> _tensors{}; std::vector<std::unique_ptr<INode>> _nodes{}; - Hint _current_hint{ Hint::DONT_CARE }; - Hint _next_hint{ Hint::DONT_CARE }; + GraphHints _current_hints{}; + GraphHints _next_hints{}; std::unique_ptr<Tensor> _graph_input{ nullptr }; std::unique_ptr<Tensor> _graph_output{ nullptr }; std::unique_ptr<INode> _current_node{ nullptr }; @@ -66,8 +67,8 @@ public: bool _info_enabled{ false }; private: - Tensor *_current_input{ nullptr }; - Hint _previous_hint{ Hint::DONT_CARE }; + Tensor *_current_input{ nullptr }; + GraphHints _previous_hints{}; }; Graph::~Graph() //NOLINT @@ -102,7 +103,7 @@ void Graph::run() } //Finalize current node's configuration -void Graph::Private::configure(Hint _next_hint) +void Graph::Private::configure(GraphHints _next_hints) { ARM_COMPUTE_ERROR_ON(_current_node == nullptr); ARM_COMPUTE_ERROR_ON(_graph_input == nullptr); @@ -110,9 +111,9 @@ void Graph::Private::configure(Hint _next_hint) // Is it the first node of the graph ? if(_current_input == nullptr) { - _graph_input->set_target(_current_hint); - _current_input = _graph_input.get(); - _previous_hint = _current_hint; // For the first node just assume the previous node was of the same type as this one + _graph_input->set_target(_current_hints.target_hint()); + _current_input = _graph_input.get(); + _previous_hints = _current_hints; // For the first node just assume the previous node was of the same type as this one } //Automatic output configuration ? @@ -123,29 +124,31 @@ void Graph::Private::configure(Hint _next_hint) } // If either the writer or reader node needs OpenCL then use OpenCL memory: - if((_next_hint == Hint::OPENCL || _current_hint == Hint::OPENCL)) + if((_next_hints.target_hint() == TargetHint::OPENCL || _current_hints.target_hint() == TargetHint::OPENCL)) { - _current_output->set_target(Hint::OPENCL); + _current_output->set_target(TargetHint::OPENCL); } else { - _current_output->set_target(Hint::NEON); + _current_output->set_target(TargetHint::NEON); } - // Map input if needed - std::unique_ptr<arm_compute::IFunction> func = _current_node->instantiate_node(_current_hint, _current_input->tensor(), _current_output->tensor()); + // Update ctx and instantiate node + _ctx.hints() = _current_hints; + std::unique_ptr<arm_compute::IFunction> func = _current_node->instantiate_node(_ctx, _current_input->tensor(), _current_output->tensor()); _current_input->allocate(); - if(_current_input->target() == Hint::OPENCL) + // Map input if needed + if(_current_input->target() == TargetHint::OPENCL) { - if(_previous_hint == Hint::NEON) + if(_previous_hints.target_hint() == TargetHint::NEON) { - ARM_COMPUTE_ERROR_ON(_current_hint == Hint::NEON); + ARM_COMPUTE_ERROR_ON(_current_hints.target_hint() == TargetHint::NEON); _pipeline.push_back({ _current_input, _current_input, arm_compute::support::cpp14::make_unique<CLUnmap>(_current_input) }); } - if(_current_hint == Hint::NEON) + if(_current_hints.target_hint() == TargetHint::NEON) { - ARM_COMPUTE_ERROR_ON(_previous_hint == Hint::NEON); + ARM_COMPUTE_ERROR_ON(_previous_hints.target_hint() == TargetHint::NEON); _pipeline.push_back({ _current_input, _current_input, arm_compute::support::cpp14::make_unique<CLMap>(_current_input, true) }); } } @@ -154,8 +157,8 @@ void Graph::Private::configure(Hint _next_hint) _current_input = _current_output; _current_output = nullptr; - _previous_hint = _current_hint; - _current_hint = _next_hint; + std::swap(_previous_hints, _current_hints); + std::swap(_current_hints, _next_hints); } void Graph::Private::set_info_enablement(bool is_enabled) @@ -169,12 +172,13 @@ void Graph::add_node(std::unique_ptr<INode> node) ARM_COMPUTE_ERROR_ON_MSG(_pimpl->_graph_output != nullptr, "Nothing can be added after the output tensor"); //Trigger the creation of the current Node: - Hint _next_hint = node->override_hint(_pimpl->_next_hint); - ARM_COMPUTE_ERROR_ON(_next_hint == Hint::DONT_CARE); + GraphHints _next_hints = _pimpl->_next_hints; + _next_hints.set_target_hint(node->override_target_hint(_pimpl->_next_hints.target_hint())); + ARM_COMPUTE_ERROR_ON(_next_hints.target_hint() == TargetHint::DONT_CARE); if(_pimpl->_current_node) { //Finalize the previous Node: - _pimpl->configure(_pimpl->_next_hint); + _pimpl->configure(_pimpl->_next_hints); if(_pimpl->_info_enabled) { @@ -183,8 +187,8 @@ void Graph::add_node(std::unique_ptr<INode> node) } else { - // If that's the first node then use the same Hint before and after the node. - _pimpl->_current_hint = _next_hint; + // If that's the first node then use the same TargetHint before and after the node. + _pimpl->_current_hints = _next_hints; } if(_pimpl->_current_node) { @@ -192,15 +196,6 @@ void Graph::add_node(std::unique_ptr<INode> node) } _pimpl->_current_node = std::move(node); } -void Graph::set_hint(Hint hint) -{ - _pimpl->_next_hint = hint; -} - -void Graph::set_info_enablement(bool is_enabled) -{ - _pimpl->set_info_enablement(is_enabled); -} //Add a tensor with an Accessor (i.e either the input or output of the graph) void Graph::add_tensor(std::unique_ptr<Tensor> tensor) @@ -221,7 +216,7 @@ void Graph::add_tensor(std::unique_ptr<Tensor> tensor) _pimpl->_current_output = _pimpl->_graph_output.get(); // Finalize the graph by configuring the last Node of the graph: - _pimpl->configure(_pimpl->_current_hint); // Ignore _next_hint as this is the last node, and just use the same hint as before this node. + _pimpl->configure(_pimpl->_current_hints); // Ignore _next_hint as this is the last node, and just use the same hint as before this node. _pimpl->_graph_output->allocate(); } } @@ -236,6 +231,16 @@ void Graph::set_temp(TensorInfo &&tmp) _pimpl->_current_output = _pimpl->_tensors.back().get(); } +void Graph::set_info_enablement(bool is_enabled) +{ + _pimpl->set_info_enablement(is_enabled); +} + +GraphHints &Graph::hints() +{ + return _pimpl->_next_hints; +} + Graph &arm_compute::graph::operator<<(Graph &graph, TensorInfo &&info) { graph.set_temp(std::move(info)); @@ -248,8 +253,14 @@ Graph &arm_compute::graph::operator<<(Graph &graph, Tensor &&tensor) return graph; } -Graph &arm_compute::graph::operator<<(Graph &graph, Hint hint) +Graph &arm_compute::graph::operator<<(Graph &graph, TargetHint target_hint) +{ + graph.hints().set_target_hint(target_hint); + return graph; +} + +Graph &arm_compute::graph::operator<<(Graph &graph, ConvolutionMethodHint conv_method_hint) { - graph.set_hint(hint); + graph.hints().set_convolution_method_hint(conv_method_hint); return graph; } |