aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/frontend/Layers.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/frontend/Layers.h')
-rw-r--r--arm_compute/graph/frontend/Layers.h54
1 files changed, 51 insertions, 3 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 7ed448e3f2..78a3f20f1f 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -225,7 +225,27 @@ public:
*/
template <typename... Ts>
ConcatLayer(SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams)
- : _sub_streams()
+ : _sub_streams(), _axis(DataLayoutDimension::CHANNEL)
+ {
+ _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream1)));
+ _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream2)));
+
+ utility::for_each([&](SubStream && sub_stream)
+ {
+ _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream)));
+ },
+ std::move(rest_sub_streams)...);
+ }
+ /** Construct a concatenation layer
+ *
+ * @param[in] axis Axis over the concatenation will be performed
+ * @param[in] sub_stream1 First graph branch
+ * @param[in] sub_stream2 Second graph branch
+ * @param[in] rest_sub_streams Rest sub-graph branches
+ */
+ template <typename... Ts>
+ ConcatLayer(DataLayoutDimension axis, SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams)
+ : _sub_streams(), _axis(axis)
{
_sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream1)));
_sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream2)));
@@ -242,7 +262,7 @@ public:
*/
template <typename... Ts>
ConcatLayer(SubStream &&sub_stream)
- : _sub_streams()
+ : _sub_streams(), _axis(DataLayoutDimension::CHANNEL)
{
_sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream)));
}
@@ -269,13 +289,14 @@ public:
}
}
}
- nid = GraphBuilder::add_concatenate_node(s.graph(), common_params, nodes, DataLayoutDimension::CHANNEL);
+ nid = GraphBuilder::add_concatenate_node(s.graph(), common_params, nodes, _axis);
}
return nid;
}
private:
std::vector<std::unique_ptr<SubStream>> _sub_streams;
+ DataLayoutDimension _axis;
};
/** Convolution Layer */
@@ -724,6 +745,33 @@ private:
PoolingLayerInfo _pool_info;
};
+/** PriorBox Layer */
+class PriorBoxLayer final : public ILayer
+{
+public:
+ /** Construct a priorbox layer.
+ *
+ * @param[in] sub_stream First graph sub-stream
+ * @param[in] prior_info PriorBox parameters.
+ */
+ PriorBoxLayer(SubStream &&sub_stream, PriorBoxLayerInfo prior_info)
+ : _ss(std::move(sub_stream)), _prior_info(prior_info)
+ {
+ }
+
+ NodeID create_layer(IStream &s) override
+ {
+ NodeParams common_params = { name(), s.hints().target_hint };
+ NodeIdxPair input0 = { s.tail_node(), 0 };
+ NodeIdxPair input1 = { _ss.tail_node(), 0 };
+ return GraphBuilder::add_priorbox_node(s.graph(), common_params, input0, input1, _prior_info);
+ }
+
+private:
+ SubStream _ss;
+ PriorBoxLayerInfo _prior_info;
+};
+
/** Reorg Layer */
class ReorgLayer final : public ILayer
{