diff options
Diffstat (limited to 'arm_compute/graph/frontend/Layers.h')
-rw-r--r-- | arm_compute/graph/frontend/Layers.h | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h index 67dc06c878..4e6f0eee2d 100644 --- a/arm_compute/graph/frontend/Layers.h +++ b/arm_compute/graph/frontend/Layers.h @@ -1038,6 +1038,92 @@ private: float _beta; }; +/** Stack Layer */ +class StackLayer final : public ILayer +{ +public: + /** Construct a concatenation layer + * + * @param[in] sub_stream1 First graph branch + * @param[in] sub_stream2 Second graph branch + * @param[in] rest_sub_streams Rest sub-graph branches + */ + template <typename... Ts> + StackLayer(SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams) + : _sub_streams(), _axis(0) + { + _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream1))); + _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream2))); + + utility::for_each([&](SubStream && sub_stream) + { + _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream))); + }, + std::move(rest_sub_streams)...); + } + /** Construct a concatenation layer + * + * @param[in] axis Stack layer axis along which to stack the inputs + * @param[in] sub_stream1 First graph branch + * @param[in] sub_stream2 Second graph branch + * @param[in] rest_sub_streams Rest sub-graph branches + */ + template <typename... Ts> + StackLayer(int axis, SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams) + : _sub_streams(), _axis(axis) + { + _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream1))); + _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream2))); + + utility::for_each([&](SubStream && sub_stream) + { + _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream))); + }, + std::move(rest_sub_streams)...); + } + /** Construct a concat layer + * + * @param[in] sub_stream Sub-stream + */ + template <typename... Ts> + StackLayer(SubStream &&sub_stream) + : _sub_streams(), _axis(0) + { + _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream))); + } + NodeID create_layer(IStream &s) override + { + NodeID nid = EmptyNodeID; + NodeParams common_params = { name(), s.hints().target_hint }; + if(_sub_streams.size() == 1 && _sub_streams.at(0) != nullptr) + { + nid = _sub_streams[0]->tail_node(); + } + else + { + // Collect tail nodes and stack + std::vector<NodeIdxPair> nodes; + for(auto &ss : _sub_streams) + { + if(ss && (ss->tail_node() != EmptyNodeID)) + { + const auto tail_node = s.graph().node(ss->tail_node()); + if(tail_node != nullptr && tail_node->type() != NodeType::Output) + { + nodes.push_back({ ss->tail_node(), 0 }); + } + } + } + nid = GraphBuilder::add_stack_node(s.graph(), common_params, nodes, _axis); + } + return nid; + } + +private: + std::vector<std::unique_ptr<SubStream>> _sub_streams; + int _axis; +}; + /** Upsample Layer */ class UpsampleLayer final : public ILayer { |