diff options
Diffstat (limited to 'arm_compute/graph/frontend')
-rw-r--r-- | arm_compute/graph/frontend/Layers.h | 27 | ||||
-rw-r--r-- | arm_compute/graph/frontend/Types.h | 1 |
2 files changed, 28 insertions, 0 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h index 0a1a0cf1e4..cf80dd9f4e 100644 --- a/arm_compute/graph/frontend/Layers.h +++ b/arm_compute/graph/frontend/Layers.h @@ -445,6 +445,33 @@ private: NormalizationLayerInfo _norm_info; }; +/** Permute Layer */ +class PermuteLayer final : public ILayer +{ +public: + /** Construct a permute layer. + * + * @param[in] perm Permutation vector. + * @param[in] layout (Optional) Data layout to assign to permuted tensor. + * If UNKNOWN then the input's layout will be used. + */ + PermuteLayer(PermutationVector perm, DataLayout layout = DataLayout::UNKNOWN) + : _perm(perm), _layout(layout) + { + } + + NodeID create_layer(IStream &s) override + { + NodeParams common_params = { name(), s.hints().target_hint }; + NodeIdxPair input = { s.tail_node(), 0 }; + return GraphBuilder::add_permute_node(s.graph(), common_params, input, _perm, _layout); + } + +private: + PermutationVector _perm; + DataLayout _layout; +}; + /** Pooling Layer */ class PoolingLayer final : public ILayer { diff --git a/arm_compute/graph/frontend/Types.h b/arm_compute/graph/frontend/Types.h index f9d4952765..8f6312f318 100644 --- a/arm_compute/graph/frontend/Types.h +++ b/arm_compute/graph/frontend/Types.h @@ -36,6 +36,7 @@ namespace frontend using graph::DataType; using graph::DataLayout; using graph::TensorShape; +using graph::PermutationVector; using graph::ActivationLayerInfo; using graph::NormalizationLayerInfo; |