diff options
Diffstat (limited to 'arm_compute/graph/nodes')
-rw-r--r-- | arm_compute/graph/nodes/ConcatenateLayerNode.h (renamed from arm_compute/graph/nodes/DepthConcatenateLayerNode.h) | 37 | ||||
-rw-r--r-- | arm_compute/graph/nodes/ConvolutionLayerNode.h | 6 | ||||
-rw-r--r-- | arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h | 2 | ||||
-rw-r--r-- | arm_compute/graph/nodes/Nodes.h | 2 | ||||
-rw-r--r-- | arm_compute/graph/nodes/NodesFwd.h | 2 |
5 files changed, 29 insertions, 20 deletions
diff --git a/arm_compute/graph/nodes/DepthConcatenateLayerNode.h b/arm_compute/graph/nodes/ConcatenateLayerNode.h index ffdec709ef..20c8523752 100644 --- a/arm_compute/graph/nodes/DepthConcatenateLayerNode.h +++ b/arm_compute/graph/nodes/ConcatenateLayerNode.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__ -#define __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__ +#ifndef __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__ #include "arm_compute/graph/INode.h" @@ -30,30 +30,31 @@ namespace arm_compute { namespace graph { -/** Depth Concatenation Layer node */ -class DepthConcatenateLayerNode final : public INode +/** Concatenation Layer node */ +class ConcatenateLayerNode final : public INode { public: /** Constructor * * @param[in] total_nodes Number of nodes that will get concatenated + * @param[in] axis Concatenation axis */ - DepthConcatenateLayerNode(unsigned int total_nodes); - /** Computes depth concatenations output descriptor + ConcatenateLayerNode(unsigned int total_nodes, DataLayoutDimension axis); + /** Computes concatenations output descriptor * * @param[in] input_descriptors Input descriptors + * @param[in] axis Concatenation axis * * @return Expected output descriptor */ - static TensorDescriptor compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors); + static TensorDescriptor compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors, DataLayoutDimension axis); /** Disables or not the depth concatenate node * - * @warning This is used when depth concatenate is performed with sub-tensors, - * where this node is used as a placeholder. + * @warning This is used when concatenate is performed using sub-tensors, where this node is used as a placeholder. * - * @param[in] is_enabled If true a backend function is created to perform the depth concatenation (involves copying), - * while if false, no function is created and we assume that subtensors are properly set to simulate - * a no copy operation. + * @param[in] is_enabled If true a backend function is created to perform the concatenation (involves copying), + * while if false, no function is created and we assume that sub-tensors are properly set to simulate + * a zero copy operation. */ void set_enabled(bool is_enabled); /** Enabled parameter accessor @@ -61,6 +62,11 @@ public: * @return True if a backend function is to be created else false */ bool is_enabled() const; + /** Concatenation axis parameter accessor + * + * @return Concatenation axis + */ + DataLayoutDimension concatenation_axis() const; // Inherited overridden methods: NodeType type() const override; @@ -69,9 +75,10 @@ public: void accept(INodeVisitor &v) override; private: - unsigned int _total_nodes; - bool _is_enabled; + unsigned int _total_nodes; + DataLayoutDimension _axis; + bool _is_enabled; }; } // namespace graph } // namespace arm_compute -#endif /* __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__ */ +#endif /* __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__ */ diff --git a/arm_compute/graph/nodes/ConvolutionLayerNode.h b/arm_compute/graph/nodes/ConvolutionLayerNode.h index aca60283d7..4299be6bb5 100644 --- a/arm_compute/graph/nodes/ConvolutionLayerNode.h +++ b/arm_compute/graph/nodes/ConvolutionLayerNode.h @@ -41,8 +41,10 @@ public: * @param[in] fast_math_hint (Optional) Fast math hint * @param[in] out_quant_info (Optional) Output quantization info */ - ConvolutionLayerNode(PadStrideInfo info, ConvolutionMethod method = ConvolutionMethod::DEFAULT, FastMathHint fast_math_hint = FastMathHint::DISABLED, - QuantizationInfo out_quant_info = QuantizationInfo()); + ConvolutionLayerNode(PadStrideInfo info, + ConvolutionMethod method = ConvolutionMethod::Default, + FastMathHint fast_math_hint = FastMathHint::Disabled, + QuantizationInfo out_quant_info = QuantizationInfo()); /** Sets the convolution layer method to use * * @param[in] method Method to use for convolution diff --git a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h index df6f456ac9..1a173c5421 100644 --- a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h +++ b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h @@ -39,7 +39,7 @@ public: * @param[in] info Convolution layer attributes * @param[in] method Depthwise convolution method to use */ - DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::DEFAULT); + DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::Default); /** Sets the depthwise convolution method to use * * @param[in] method Depthwise convolution method to use diff --git a/arm_compute/graph/nodes/Nodes.h b/arm_compute/graph/nodes/Nodes.h index 97aa191916..f2e751e15f 100644 --- a/arm_compute/graph/nodes/Nodes.h +++ b/arm_compute/graph/nodes/Nodes.h @@ -27,10 +27,10 @@ #include "arm_compute/graph/nodes/ActivationLayerNode.h" #include "arm_compute/graph/nodes/BatchNormalizationLayerNode.h" #include "arm_compute/graph/nodes/ChannelShuffleLayerNode.h" +#include "arm_compute/graph/nodes/ConcatenateLayerNode.h" #include "arm_compute/graph/nodes/ConstNode.h" #include "arm_compute/graph/nodes/ConvolutionLayerNode.h" #include "arm_compute/graph/nodes/DeconvolutionLayerNode.h" -#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h" #include "arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h" #include "arm_compute/graph/nodes/DummyNode.h" #include "arm_compute/graph/nodes/EltwiseLayerNode.h" diff --git a/arm_compute/graph/nodes/NodesFwd.h b/arm_compute/graph/nodes/NodesFwd.h index 05979d796c..a0a9146dc4 100644 --- a/arm_compute/graph/nodes/NodesFwd.h +++ b/arm_compute/graph/nodes/NodesFwd.h @@ -33,10 +33,10 @@ class INode; class ActivationLayerNode; class BatchNormalizationLayerNode; class ChannelShuffleLayerNode; +class ConcatenateLayerNode; class ConstNode; class ConvolutionLayerNode; class DeconvolutionLayerNode; -class DepthConcatenateLayerNode; class DepthwiseConvolutionLayerNode; class DummyNode; class EltwiseLayerNode; |