ArmNN
 21.11
ChannelShuffleLayer.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
8 
9 namespace armnn
10 {
11 class ChannelShuffleLayer : public LayerWithParameters<ChannelShuffleDescriptor>
12 {
13 public:
15  void Accept(ILayerVisitor& visitor) const override;
17 
18  /// Creates a dynamically-allocated copy of this layer.
19  /// @param graph The graph into which this layer is being cloned
20  ChannelShuffleLayer* Clone(Graph& graph) const override;
21 
22  /// Makes a workload for the ChannelShuffle type.
23  /// @param factory The workload factory which will create the workload
24  /// @return A pointer to the created workload, or nullptr if not created.
25  virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override;
26 
27  /// Check if the input tensor shape(s)
28  /// will lead to a valid configuration of @ref ChannelShuffleLayer.
29  /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validated.
30  void ValidateTensorShapesFromInputs() override;
31 
32  // TODO Do you need to create an InferOutputShapes function for ChannelShuffle?
33 protected:
34  ChannelShuffleLayer(const ChannelShuffleDescriptor& param, const char* name);
35 
36  ~ChannelShuffleLayer() = default;
37 };
38 
39 } // namespace
ARMNN_NO_DEPRECATE_WARN_END ChannelShuffleLayer * Clone(Graph &graph) const override
Creates a dynamically-allocated copy of this layer.
ChannelShuffleLayer(const ChannelShuffleDescriptor &param, const char *name)
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
Definition: Deprecated.hpp:33
Copyright (c) 2021 ARM Limited and Contributors.
void ValidateTensorShapesFromInputs() override
Check if the input tensor shape(s) will lead to a valid configuration of ChannelShuffleLayer.
ARMNN_NO_DEPRECATE_WARN_BEGIN void Accept(ILayerVisitor &visitor) const override
#define ARMNN_NO_DEPRECATE_WARN_END
Definition: Deprecated.hpp:34
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
Makes a workload for the ChannelShuffle type.
A ChannelShuffleDescriptor for the ChannelShuffle operator.