aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/nodes
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-20 13:23:44 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commite2220551b7a64b929650ba9a60529c31e70c13c5 (patch)
tree5d609887f15b4392cdade7bb388710ceafc62260 /arm_compute/graph/nodes
parenteff8d95991205e874091576e2d225f63246dd0bb (diff)
downloadComputeLibrary-e2220551b7a64b929650ba9a60529c31e70c13c5.tar.gz
COMPMID-1367: Enable NHWC in graph examples
Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449 Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph/nodes')
-rw-r--r--arm_compute/graph/nodes/ConcatenateLayerNode.h (renamed from arm_compute/graph/nodes/DepthConcatenateLayerNode.h)37
-rw-r--r--arm_compute/graph/nodes/ConvolutionLayerNode.h6
-rw-r--r--arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h2
-rw-r--r--arm_compute/graph/nodes/Nodes.h2
-rw-r--r--arm_compute/graph/nodes/NodesFwd.h2
5 files changed, 29 insertions, 20 deletions
diff --git a/arm_compute/graph/nodes/DepthConcatenateLayerNode.h b/arm_compute/graph/nodes/ConcatenateLayerNode.h
index ffdec709ef..20c8523752 100644
--- a/arm_compute/graph/nodes/DepthConcatenateLayerNode.h
+++ b/arm_compute/graph/nodes/ConcatenateLayerNode.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__
-#define __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__
+#ifndef __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__
+#define __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__
#include "arm_compute/graph/INode.h"
@@ -30,30 +30,31 @@ namespace arm_compute
{
namespace graph
{
-/** Depth Concatenation Layer node */
-class DepthConcatenateLayerNode final : public INode
+/** Concatenation Layer node */
+class ConcatenateLayerNode final : public INode
{
public:
/** Constructor
*
* @param[in] total_nodes Number of nodes that will get concatenated
+ * @param[in] axis Concatenation axis
*/
- DepthConcatenateLayerNode(unsigned int total_nodes);
- /** Computes depth concatenations output descriptor
+ ConcatenateLayerNode(unsigned int total_nodes, DataLayoutDimension axis);
+ /** Computes concatenations output descriptor
*
* @param[in] input_descriptors Input descriptors
+ * @param[in] axis Concatenation axis
*
* @return Expected output descriptor
*/
- static TensorDescriptor compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors);
+ static TensorDescriptor compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors, DataLayoutDimension axis);
/** Disables or not the depth concatenate node
*
- * @warning This is used when depth concatenate is performed with sub-tensors,
- * where this node is used as a placeholder.
+ * @warning This is used when concatenate is performed using sub-tensors, where this node is used as a placeholder.
*
- * @param[in] is_enabled If true a backend function is created to perform the depth concatenation (involves copying),
- * while if false, no function is created and we assume that subtensors are properly set to simulate
- * a no copy operation.
+ * @param[in] is_enabled If true a backend function is created to perform the concatenation (involves copying),
+ * while if false, no function is created and we assume that sub-tensors are properly set to simulate
+ * a zero copy operation.
*/
void set_enabled(bool is_enabled);
/** Enabled parameter accessor
@@ -61,6 +62,11 @@ public:
* @return True if a backend function is to be created else false
*/
bool is_enabled() const;
+ /** Concatenation axis parameter accessor
+ *
+ * @return Concatenation axis
+ */
+ DataLayoutDimension concatenation_axis() const;
// Inherited overridden methods:
NodeType type() const override;
@@ -69,9 +75,10 @@ public:
void accept(INodeVisitor &v) override;
private:
- unsigned int _total_nodes;
- bool _is_enabled;
+ unsigned int _total_nodes;
+ DataLayoutDimension _axis;
+ bool _is_enabled;
};
} // namespace graph
} // namespace arm_compute
-#endif /* __ARM_COMPUTE_GRAPH_DEPTH_CONCATENATE_LAYER_NODE_H__ */
+#endif /* __ARM_COMPUTE_GRAPH_CONCATENATE_LAYER_NODE_H__ */
diff --git a/arm_compute/graph/nodes/ConvolutionLayerNode.h b/arm_compute/graph/nodes/ConvolutionLayerNode.h
index aca60283d7..4299be6bb5 100644
--- a/arm_compute/graph/nodes/ConvolutionLayerNode.h
+++ b/arm_compute/graph/nodes/ConvolutionLayerNode.h
@@ -41,8 +41,10 @@ public:
* @param[in] fast_math_hint (Optional) Fast math hint
* @param[in] out_quant_info (Optional) Output quantization info
*/
- ConvolutionLayerNode(PadStrideInfo info, ConvolutionMethod method = ConvolutionMethod::DEFAULT, FastMathHint fast_math_hint = FastMathHint::DISABLED,
- QuantizationInfo out_quant_info = QuantizationInfo());
+ ConvolutionLayerNode(PadStrideInfo info,
+ ConvolutionMethod method = ConvolutionMethod::Default,
+ FastMathHint fast_math_hint = FastMathHint::Disabled,
+ QuantizationInfo out_quant_info = QuantizationInfo());
/** Sets the convolution layer method to use
*
* @param[in] method Method to use for convolution
diff --git a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h
index df6f456ac9..1a173c5421 100644
--- a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h
+++ b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h
@@ -39,7 +39,7 @@ public:
* @param[in] info Convolution layer attributes
* @param[in] method Depthwise convolution method to use
*/
- DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::DEFAULT);
+ DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::Default);
/** Sets the depthwise convolution method to use
*
* @param[in] method Depthwise convolution method to use
diff --git a/arm_compute/graph/nodes/Nodes.h b/arm_compute/graph/nodes/Nodes.h
index 97aa191916..f2e751e15f 100644
--- a/arm_compute/graph/nodes/Nodes.h
+++ b/arm_compute/graph/nodes/Nodes.h
@@ -27,10 +27,10 @@
#include "arm_compute/graph/nodes/ActivationLayerNode.h"
#include "arm_compute/graph/nodes/BatchNormalizationLayerNode.h"
#include "arm_compute/graph/nodes/ChannelShuffleLayerNode.h"
+#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
#include "arm_compute/graph/nodes/ConstNode.h"
#include "arm_compute/graph/nodes/ConvolutionLayerNode.h"
#include "arm_compute/graph/nodes/DeconvolutionLayerNode.h"
-#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h"
#include "arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h"
#include "arm_compute/graph/nodes/DummyNode.h"
#include "arm_compute/graph/nodes/EltwiseLayerNode.h"
diff --git a/arm_compute/graph/nodes/NodesFwd.h b/arm_compute/graph/nodes/NodesFwd.h
index 05979d796c..a0a9146dc4 100644
--- a/arm_compute/graph/nodes/NodesFwd.h
+++ b/arm_compute/graph/nodes/NodesFwd.h
@@ -33,10 +33,10 @@ class INode;
class ActivationLayerNode;
class BatchNormalizationLayerNode;
class ChannelShuffleLayerNode;
+class ConcatenateLayerNode;
class ConstNode;
class ConvolutionLayerNode;
class DeconvolutionLayerNode;
-class DepthConcatenateLayerNode;
class DepthwiseConvolutionLayerNode;
class DummyNode;
class EltwiseLayerNode;