diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-07-20 13:23:44 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | e2220551b7a64b929650ba9a60529c31e70c13c5 (patch) | |
tree | 5d609887f15b4392cdade7bb388710ceafc62260 /arm_compute/graph/nodes | |
parent | eff8d95991205e874091576e2d225f63246dd0bb (diff) | |
download | ComputeLibrary-e2220551b7a64b929650ba9a60529c31e70c13c5.tar.gz |
COMPMID-1367: Enable NHWC in graph examples
Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph/nodes')
-rw-r--r-- | arm_compute/graph/nodes/ConcatenateLayerNode.h (renamed from arm_compute/graph/nodes/DepthConcatenateLayerNode.h) | 37 | ||||
-rw-r--r-- | arm_compute/graph/nodes/ConvolutionLayerNode.h | 6 | ||||
-rw-r--r-- | arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h | 2 | ||||
-rw-r--r-- | arm_compute/graph/nodes/Nodes.h | 2 | ||||
-rw-r--r-- | arm_compute/graph/nodes/NodesFwd.h | 2 |
5 files changed, 29 insertions, 20 deletions
diff --git a/arm_compute/graph/nodes/DepthConcatenateLayerNode.h b/arm_compute/graph/nodes/ConcatenateLayerNode.h index ffdec709ef..20c8523752 100644 --- a/arm_compute/graph/nodes/DepthConcatenateLayerNode.h +++ b/arm_compute/graph/nodes/ConcatenateLayerNode.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__ -#define __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__ +#ifndef __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__ #include "arm_compute/graph/INode.h" @@ -30,30 +30,31 @@ namespace arm_compute { namespace graph { -/** Depth Concatenation Layer node */ -class DepthConcatenateLayerNode final : public INode +/** Concatenation Layer node */ +class ConcatenateLayerNode final : public INode { public: /** Constructor * * @param[in] total_nodes Number of nodes that will get concatenated + * @param[in] axis Concatenation axis */ - DepthConcatenateLayerNode(unsigned int total_nodes); - /** Computes depth concatenations output descriptor + ConcatenateLayerNode(unsigned int total_nodes, DataLayoutDimension axis); + /** Computes concatenations output descriptor * * @param[in] input_descriptors Input descriptors + * @param[in] axis Concatenation axis * * @return Expected output descriptor */ - static TensorDescriptor compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors); + static TensorDescriptor compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors, DataLayoutDimension axis); /** Disables or not the depth concatenate node * - * @warning This is used when depth concatenate is performed with sub-tensors, - * where this node is used as a placeholder. + * @warning This is used when concatenate is performed using sub-tensors, where this node is used as a placeholder. * - * @param[in] is_enabled If true a backend function is created to perform the depth concatenation (involves copying), - * while if false, no function is created and we assume that subtensors are properly set to simulate - * a no copy operation. + * @param[in] is_enabled If true a backend function is created to perform the concatenation (involves copying), + * while if false, no function is created and we assume that sub-tensors are properly set to simulate + * a zero copy operation. */ void set_enabled(bool is_enabled); /** Enabled parameter accessor @@ -61,6 +62,11 @@ public: * @return True if a backend function is to be created else false */ bool is_enabled() const; + /** Concatenation axis parameter accessor + * + * @return Concatenation axis + */ + DataLayoutDimension concatenation_axis() const; // Inherited overridden methods: NodeType type() const override; @@ -69,9 +75,10 @@ public: void accept(INodeVisitor &v) override; private: - unsigned int _total_nodes; - bool _is_enabled; + unsigned int _total_nodes; + DataLayoutDimension _axis; + bool _is_enabled; }; } // namespace graph } // namespace arm_compute -#endif /* __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__ */ +#endif /* __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__ */ diff --git a/arm_compute/graph/nodes/ConvolutionLayerNode.h b/arm_compute/graph/nodes/ConvolutionLayerNode.h index aca60283d7..4299be6bb5 100644 --- a/arm_compute/graph/nodes/ConvolutionLayerNode.h +++ b/arm_compute/graph/nodes/ConvolutionLayerNode.h @@ -41,8 +41,10 @@ public: * @param[in] fast_math_hint (Optional) Fast math hint * @param[in] out_quant_info (Optional) Output quantization info */ - ConvolutionLayerNode(PadStrideInfo info, ConvolutionMethod method = ConvolutionMethod::DEFAULT, FastMathHint fast_math_hint = FastMathHint::DISABLED, - QuantizationInfo out_quant_info = QuantizationInfo()); + ConvolutionLayerNode(PadStrideInfo info, + ConvolutionMethod method = ConvolutionMethod::Default, + FastMathHint fast_math_hint = FastMathHint::Disabled, + QuantizationInfo out_quant_info = QuantizationInfo()); /** Sets the convolution layer method to use * * @param[in] method Method to use for convolution diff --git a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h index df6f456ac9..1a173c5421 100644 --- a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h +++ b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h @@ -39,7 +39,7 @@ public: * @param[in] info Convolution layer attributes * @param[in] method Depthwise convolution method to use */ - DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::DEFAULT); + DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::Default); /** Sets the depthwise convolution method to use * * @param[in] method Depthwise convolution method to use diff --git a/arm_compute/graph/nodes/Nodes.h b/arm_compute/graph/nodes/Nodes.h index 97aa191916..f2e751e15f 100644 --- a/arm_compute/graph/nodes/Nodes.h +++ b/arm_compute/graph/nodes/Nodes.h @@ -27,10 +27,10 @@ #include "arm_compute/graph/nodes/ActivationLayerNode.h" #include "arm_compute/graph/nodes/BatchNormalizationLayerNode.h" #include "arm_compute/graph/nodes/ChannelShuffleLayerNode.h" +#include "arm_compute/graph/nodes/ConcatenateLayerNode.h" #include "arm_compute/graph/nodes/ConstNode.h" #include "arm_compute/graph/nodes/ConvolutionLayerNode.h" #include "arm_compute/graph/nodes/DeconvolutionLayerNode.h" -#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h" #include "arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h" #include "arm_compute/graph/nodes/DummyNode.h" #include "arm_compute/graph/nodes/EltwiseLayerNode.h" diff --git a/arm_compute/graph/nodes/NodesFwd.h b/arm_compute/graph/nodes/NodesFwd.h index 05979d796c..a0a9146dc4 100644 --- a/arm_compute/graph/nodes/NodesFwd.h +++ b/arm_compute/graph/nodes/NodesFwd.h @@ -33,10 +33,10 @@ class INode; class ActivationLayerNode; class BatchNormalizationLayerNode; class ChannelShuffleLayerNode; +class ConcatenateLayerNode; class ConstNode; class ConvolutionLayerNode; class DeconvolutionLayerNode; -class DepthConcatenateLayerNode; class DepthwiseConvolutionLayerNode; class DummyNode; class EltwiseLayerNode; |