diff options
Diffstat (limited to 'src/graph/INode.cpp')
-rw-r--r-- | src/graph/INode.cpp | 176 |
1 files changed, 157 insertions, 19 deletions
diff --git a/src/graph/INode.cpp b/src/graph/INode.cpp index c753f66b43..c1c18e5853 100644 --- a/src/graph/INode.cpp +++ b/src/graph/INode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -23,33 +23,171 @@ */ #include "arm_compute/graph/INode.h" -#include "arm_compute/core/CL/OpenCL.h" -#include "arm_compute/core/Validate.h" +#include "arm_compute/core/Error.h" +#include "arm_compute/graph/Edge.h" +#include "arm_compute/graph/Graph.h" +#include "arm_compute/graph/Tensor.h" -using namespace arm_compute::graph; +namespace arm_compute +{ +namespace graph +{ +// *INDENT-OFF* +// clang-format off +INode::INode() + : _graph(nullptr), _id(EmptyNodeID), _common_params({ "", Target::UNSPECIFIED}), + _outputs(), _input_edges(), _output_edges(), _assigned_target(Target::UNSPECIFIED) +{ +} +// clang-format on +// *INDENT-ON* -TargetHint INode::override_target_hint(TargetHint target_hint) const +void INode::set_graph(Graph *g) { - if(target_hint == TargetHint::OPENCL && !opencl_is_available()) + ARM_COMPUTE_ERROR_ON(g == nullptr); + _graph = g; +} + +void INode::set_id(NodeID id) +{ + _id = id; +} + +void INode::set_common_node_parameters(NodeParams common_params) +{ + _common_params = std::move(common_params); +} + +void INode::set_requested_target(Target target) +{ + _common_params.target = target; +} + +void INode::set_assigned_target(Target target) +{ + _assigned_target = target; +} + +void INode::set_output_tensor(TensorID tid, size_t idx) +{ + if(tid != NullTensorID && (idx < _outputs.size()) && (_graph->tensor(tid) != nullptr)) { - target_hint = TargetHint::DONT_CARE; + ARM_COMPUTE_ERROR_ON(_graph == nullptr); + Tensor *updated_tensor = _graph->tensor(tid); + _outputs[idx] = tid; + + // Set tensor to all output edges of the node + for(auto &output_edge_id : _output_edges) + { + auto output_edge = _graph->edge(output_edge_id); + if(output_edge != nullptr) + { + // Unbind edge from current tensor + auto current_output_tensor = output_edge->tensor(); + current_output_tensor->unbind_edge(output_edge->id()); + + // Update tensor to edge and rebind tensor + output_edge->update_bound_tensor(updated_tensor); + updated_tensor->bind_edge(output_edge->id()); + } + } } - GraphHints hints{ target_hint }; - target_hint = node_override_hints(hints).target_hint(); - ARM_COMPUTE_ERROR_ON(target_hint == TargetHint::OPENCL && !opencl_is_available()); - return target_hint; } -bool INode::supports_in_place() const + +NodeID INode::id() const +{ + return _id; +} + +std::string INode::name() const +{ + return _common_params.name; +} + +const Graph *INode::graph() const +{ + return _graph; +} + +Graph *INode::graph() +{ + return _graph; +} + +const std::vector<TensorID> &INode::outputs() const { - return _supports_in_place; + return _outputs; } -void INode::set_supports_in_place(bool value) + +const std::vector<EdgeID> &INode::input_edges() const +{ + return _input_edges; +} + +const std::set<EdgeID> &INode::output_edges() const +{ + return _output_edges; +} + +TensorID INode::input_id(size_t idx) const +{ + ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size()); + Edge *e = _graph->edge(_input_edges[idx]); + return (e != nullptr) ? e->tensor_id() : NullTensorID; +} + +TensorID INode::output_id(size_t idx) const +{ + ARM_COMPUTE_ERROR_ON(idx >= _outputs.size()); + return _outputs[idx]; +} + +Tensor *INode::input(size_t idx) const +{ + ARM_COMPUTE_ERROR_ON(_graph == nullptr); + ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size()); + Edge *e = _graph->edge(_input_edges[idx]); + return (e != nullptr) ? e->tensor() : nullptr; +} + +Tensor *INode::output(size_t idx) const { - _supports_in_place = value; + ARM_COMPUTE_ERROR_ON(_graph == nullptr); + ARM_COMPUTE_ERROR_ON(idx >= _outputs.size()); + return _graph->tensor(_outputs[idx]); } -GraphHints INode::node_override_hints(GraphHints hints) const + +EdgeID INode::input_edge_id(size_t idx) const +{ + ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size()); + return _input_edges[idx]; +} + +Edge *INode::input_edge(size_t idx) const +{ + ARM_COMPUTE_ERROR_ON(_graph == nullptr); + ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size()); + return _graph->edge(_input_edges[idx]); +} + +size_t INode::num_inputs() const +{ + return _input_edges.size(); +} + +size_t INode::num_outputs() const +{ + return _outputs.size(); +} + +Target INode::requested_target() const +{ + return _common_params.target; +} + +Target INode::assigned_target() const { - TargetHint target_hint = hints.target_hint(); - hints.set_target_hint((target_hint == TargetHint::DONT_CARE) ? TargetHint::NEON : target_hint); - return hints; + return _assigned_target; } +} // namespace graph +} // namespace arm_compute
\ No newline at end of file |