author    Georgios Pinitas <georgios.pinitas@arm.com>  2018-03-08 16:01:29 +0000
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:49:16 +0000
commit    ee33ea5a6e1aa0faac1cc8b5a269bd4f89854821 (patch)
tree      0baf159ae4a61d07cc765ad6bb1a2fb42c403081
parent    e86a09fe4c5aa9037787e13ee55cba2b049d5ea5 (diff)
download  ComputeLibrary-ee33ea5a6e1aa0faac1cc8b5a269bd4f89854821.tar.gz
COMPMID-996: Add support for grouped convolution.
Change-Id: I279e29ce20b3dde57445264dc11491f127b44d70
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/124429
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/graph2/GraphBuilder.h                         |  32
-rw-r--r--  arm_compute/graph2/IDeviceBackend.h                       |   9
-rw-r--r--  arm_compute/graph2/INodeVisitor.h                         |   9
-rw-r--r--  arm_compute/graph2/Types.h                                |   2
-rw-r--r--  arm_compute/graph2/backends/CL/CLDeviceBackend.h          |   2
-rw-r--r--  arm_compute/graph2/backends/CL/CLSubTensorHandle.h        |   3
-rw-r--r--  arm_compute/graph2/backends/NEON/NEDeviceBackend.h        |   2
-rw-r--r--  arm_compute/graph2/backends/NEON/NESubTensorHandle.h      |   3
-rw-r--r--  arm_compute/graph2/frontend/Layers.h                      |   3
-rw-r--r--  arm_compute/graph2/mutators/GraphMutators.h               |   1
-rw-r--r--  arm_compute/graph2/mutators/SplitLayerSubTensorMutator.h  |  46
-rw-r--r--  arm_compute/graph2/nodes/Nodes.h                          |   1
-rw-r--r--  arm_compute/graph2/nodes/NodesFwd.h                       |   1
-rw-r--r--  arm_compute/graph2/nodes/SplitLayerNode.h                 |  79
-rw-r--r--  examples/graph_alexnet.cpp                                |  32
-rw-r--r--  src/graph2/GraphBuilder.cpp                               | 135
-rw-r--r--  src/graph2/Utils.cpp                                      |   1
-rw-r--r--  src/graph2/backends/CL/CLDeviceBackend.cpp                |   4
-rw-r--r--  src/graph2/backends/CL/CLSubTensorHandle.cpp              |   4
-rw-r--r--  src/graph2/backends/NEON/NEDeviceBackend.cpp              |   4
-rw-r--r--  src/graph2/backends/NEON/NESubTensorHandle.cpp            |   4
-rw-r--r--  src/graph2/mutators/DepthConcatSubTensorMutator.cpp       |   2
-rw-r--r--  src/graph2/mutators/SplitLayerSubTensorMutator.cpp        |  89
-rw-r--r--  src/graph2/nodes/SplitLayerNode.cpp                       | 117
24 files changed, 484 insertions, 101 deletions
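A grouped convolution partitions the input channels and the convolution kernels into num_groups independent slices, convolves each pair of slices separately, and concatenates the per-group results along the channel dimension. A minimal standalone sketch of the shape bookkeeping (the channel counts are illustrative, AlexNet-style; they are not taken from this patch):

```cpp
#include <cstdio>

int main()
{
    const unsigned int in_channels  = 96;  // input feature maps
    const unsigned int out_channels = 256; // convolution kernels (depth)
    const unsigned int num_groups   = 2;

    // Each group convolves a channel slice of the input with a slice of the kernels.
    const unsigned int in_per_group  = in_channels / num_groups;  // 48
    const unsigned int out_per_group = out_channels / num_groups; // 128

    for(unsigned int g = 0; g < num_groups; ++g)
    {
        std::printf("group %u: input channels [%u, %u), kernels [%u, %u)\n",
                    g, g * in_per_group, (g + 1) * in_per_group,
                    g * out_per_group, (g + 1) * out_per_group);
    }
    return 0;
}
```

This is exactly the decomposition the patch performs at graph level: split nodes on the input, weights and bias, one convolution node per group, and a depth concatenation at the end (see src/graph2/GraphBuilder.cpp below).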
diff --git a/arm_compute/graph2/GraphBuilder.h b/arm_compute/graph2/GraphBuilder.h
index f92746a603..f9fb251fc5 100644
--- a/arm_compute/graph2/GraphBuilder.h
+++ b/arm_compute/graph2/GraphBuilder.h
@@ -101,10 +101,11 @@ public:
*
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
- * @param[in] input Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] input Input to the convolution layer node as a NodeID-Index pair
* @param[in] kernel_spatial_extend Spatial extend of convolution kernels
* @param[in] depth Number of convolution kernels
* @param[in] conv_info Convolution layer information
+ * @param[in] num_groups (Optional) Number of groups for a grouped convolution. Defaults to 1
* @param[in] method (Optional) Convolution method to use
* @param[in] weights_accessor (Optional) Accessor of the weights node data
* @param[in] bias_accessor (Optional) Accessor of the bias node data
@@ -113,13 +114,13 @@ public:
*/
static NodeID add_convolution_node(Graph &g, NodeParams params, NodeIdxPair input,
Size2D kernel_spatial_extend, unsigned int depth, PadStrideInfo conv_info,
- ConvolutionMethod method = ConvolutionMethod::DEFAULT,
+ unsigned int num_groups = 1, ConvolutionMethod method = ConvolutionMethod::DEFAULT,
ITensorAccessorUPtr weights_accessor = nullptr, ITensorAccessorUPtr bias_accessor = nullptr);
/** Adds a depth concatenate node to the graph
*
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
- * @param[in] inputs Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] inputs Inputs to the depth concatenate layer node as a NodeID-Index pair
*
* @return Node ID of the created node, EmptyNodeID in case of error
*/
@@ -128,7 +129,7 @@ public:
*
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
- * @param[in] input Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] input Input to the depthwise convolution layer node as a NodeID-Index pair
* @param[in] kernel_spatial_extend Spatial extend of convolution kernels
* @param[in] conv_info Convolution layer information
* @param[in] method (Optional) Convolution method to use
@@ -156,7 +157,7 @@ public:
*
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
- * @param[in] input Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] input Input to the flatten layer node as a NodeID-Index pair
*
* @return Node ID of the created node, EmptyNodeID in case of error
*/
@@ -165,7 +166,7 @@ public:
*
* @param[in] g Graph to add the layer to
* @param[in] params Common node parameters
- * @param[in] input Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] input Input to the fully connected layer node as a NodeID-Index pair
* @param[in] num_outputs Number of output neurons
* @param[in] weights_accessor (Optional) Accessor of the weights node data
* @param[in] bias_accessor (Optional) Accessor of the bias node data
@@ -178,7 +179,7 @@ public:
*
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
- * @param[in] input Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] input Input to the normalization layer node as a NodeID-Index pair
* @param[in] norm_info Normalization layer information
*
* @return Node ID of the created node, EmptyNodeID in case of error
@@ -188,7 +189,7 @@ public:
*
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
- * @param[in] input Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] input Input to the pooling layer node as a NodeID-Index pair
* @param[in] pool_info Pooling layer information
*
* @return Node ID of the created node, EmptyNodeID in case of error
@@ -198,7 +199,7 @@ public:
*
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
- * @param[in] input Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] input Input to the reshape layer node as a NodeID-Index pair
* @param[in] shape Output reshaped shape
*
* @return Node ID of the created node, EmptyNodeID in case of error
@@ -208,12 +209,23 @@ public:
*
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
- * @param[in] input Input to the batch normalization layer node as a NodeID-Index pair
+ * @param[in] input Input to the softmax layer node as a NodeID-Index pair
* @param[in] beta Beta parameter
*
* @return Node ID of the created node, EmptyNodeID in case of error
*/
static NodeID add_softmax_node(Graph &g, NodeParams params, NodeIdxPair input, float beta = 1.f);
+ /** Adds a split node to the graph
+ *
+ * @param[in] g Graph to add the node to
+ * @param[in] params Common node parameters
+ * @param[in] input Input to the split layer node as a NodeID-Index pair
+ * @param[in] num_splits Number of different splits
+ * @param[in] axis (Optional) Split axis. Defaults to 0
+ *
+ * @return Node ID of the created node, EmptyNodeID in case of error
+ */
+ static NodeID add_split_node(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_splits, unsigned int axis = 0);
};
} // namespace graph2
} // namespace arm_compute
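Callers of add_convolution_node that previously passed a ConvolutionMethod positionally now have to supply the group count first. A hypothetical call site (a sketch; g, previous_node, weights and bias are placeholders, not names from the patch):

```cpp
NodeParams  params = { "conv2", Target::NEON };
NodeIdxPair input  = { previous_node, 0 };

NodeID conv = GraphBuilder::add_convolution_node(g, params, input,
                                                 Size2D(5, 5), 256U,
                                                 PadStrideInfo(1, 1, 2, 2),
                                                 2 /* num_groups */,
                                                 ConvolutionMethod::DEFAULT,
                                                 std::move(weights), std::move(bias));
```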
diff --git a/arm_compute/graph2/IDeviceBackend.h b/arm_compute/graph2/IDeviceBackend.h
index 2e8f3cb252..f0d6297b7b 100644
--- a/arm_compute/graph2/IDeviceBackend.h
+++ b/arm_compute/graph2/IDeviceBackend.h
@@ -65,13 +65,14 @@ public:
virtual std::unique_ptr<ITensorHandle> create_tensor(const Tensor &tensor) = 0;
/** Create a backend Sub-Tensor
*
- * @param[in] parent Parent sub-tensor handle
- * @param[in] shape Shape of the sub-tensor
- * @param[in] coords Starting coordinates of the sub-tensor
+ * @param[in] parent Parent sub-tensor handle
+ * @param[in] shape Shape of the sub-tensor
+ * @param[in] coords Starting coordinates of the sub-tensor
+ * @param[in] extend_parent Extends parent shape if true
*
* @return Backend sub-tensor handle
*/
- virtual std::unique_ptr<ITensorHandle> create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords) = 0;
+ virtual std::unique_ptr<ITensorHandle> create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent) = 0;
/** Configure a backend Node
*
* @note This creates an appropriate configured backend function for the given node
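Note that extend_parent has no default at this interface level, so every caller of create_subtensor must now pass the flag explicitly. A caller wanting the old behavior would write something like:

```cpp
// Sketch: pre-existing callers keep the old semantics by passing false,
// as DepthConcatSubTensorMutator does further down in this patch.
auto handle = backend->create_subtensor(parent_handle, shape, coords, /* extend_parent */ false);
```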
diff --git a/arm_compute/graph2/INodeVisitor.h b/arm_compute/graph2/INodeVisitor.h
index a7b8aeb45d..024d83c835 100644
--- a/arm_compute/graph2/INodeVisitor.h
+++ b/arm_compute/graph2/INodeVisitor.h
@@ -116,6 +116,11 @@ public:
* @param[in] n Node to visit.
*/
virtual void visit(SoftmaxLayerNode &n) = 0;
+ /** Visit SplitLayerNode.
+ *
+ * @param[in] n Node to visit.
+ */
+ virtual void visit(SplitLayerNode &n) = 0;
};
/** Default visitor implementation
@@ -195,6 +200,10 @@ public:
{
default_visit();
}
+ virtual void visit(SplitLayerNode &n) override
+ {
+ default_visit();
+ }
#endif /* DOXYGEN_SKIP_THIS */
/** Function to be overloaded by the client and implement default behavior for the
diff --git a/arm_compute/graph2/Types.h b/arm_compute/graph2/Types.h
index 2e9fe38380..4cbfc7267b 100644
--- a/arm_compute/graph2/Types.h
+++ b/arm_compute/graph2/Types.h
@@ -38,6 +38,7 @@ namespace graph2
{
using arm_compute::Status;
+using arm_compute::Coordinates;
using arm_compute::DataType;
using arm_compute::TensorShape;
using arm_compute::Size2D;
@@ -125,6 +126,7 @@ enum class NodeType
PoolingLayer,
ReshapeLayer,
SoftmaxLayer,
+ SplitLayer,
Input,
Output,
diff --git a/arm_compute/graph2/backends/CL/CLDeviceBackend.h b/arm_compute/graph2/backends/CL/CLDeviceBackend.h
index 77a8faf2c6..3a70f0b112 100644
--- a/arm_compute/graph2/backends/CL/CLDeviceBackend.h
+++ b/arm_compute/graph2/backends/CL/CLDeviceBackend.h
@@ -55,7 +55,7 @@ public:
void initialize_backend() override;
void setup_backend_context(GraphContext &ctx) override;
std::unique_ptr<ITensorHandle> create_tensor(const Tensor &tensor) override;
- std::unique_ptr<ITensorHandle> create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords) override;
+ std::unique_ptr<ITensorHandle> create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent) override;
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
diff --git a/arm_compute/graph2/backends/CL/CLSubTensorHandle.h b/arm_compute/graph2/backends/CL/CLSubTensorHandle.h
index 5584a8ba4f..9910980e59 100644
--- a/arm_compute/graph2/backends/CL/CLSubTensorHandle.h
+++ b/arm_compute/graph2/backends/CL/CLSubTensorHandle.h
@@ -43,8 +43,9 @@ public:
* @param[in] parent_handle Parent tensor handle
* @param[in] shape Sub-Tensor shape
* @param[in] coords Starting coordinates
+ * @param[in] extend_parent Extends parent shape if true
*/
- CLSubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords);
+ CLSubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords, bool extend_parent = false);
/** Destructor: free the tensor's memory */
~CLSubTensorHandle() = default;
/** Allow instances of this class to be move constructed */
diff --git a/arm_compute/graph2/backends/NEON/NEDeviceBackend.h b/arm_compute/graph2/backends/NEON/NEDeviceBackend.h
index 5d1394b2f3..e81e9d921e 100644
--- a/arm_compute/graph2/backends/NEON/NEDeviceBackend.h
+++ b/arm_compute/graph2/backends/NEON/NEDeviceBackend.h
@@ -44,7 +44,7 @@ public:
void initialize_backend() override;
void setup_backend_context(GraphContext &ctx) override;
std::unique_ptr<ITensorHandle> create_tensor(const Tensor &tensor) override;
- std::unique_ptr<ITensorHandle> create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords) override;
+ std::unique_ptr<ITensorHandle> create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent) override;
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
diff --git a/arm_compute/graph2/backends/NEON/NESubTensorHandle.h b/arm_compute/graph2/backends/NEON/NESubTensorHandle.h
index e027b0cc56..eacdfe0fb4 100644
--- a/arm_compute/graph2/backends/NEON/NESubTensorHandle.h
+++ b/arm_compute/graph2/backends/NEON/NESubTensorHandle.h
@@ -43,8 +43,9 @@ public:
* @param[in] parent_handle Parent tensor handle
* @param[in] shape Sub-Tensor shape
* @param[in] coords Starting coordinates
+ * @param[in] extend_parent Extends parent shape if true
*/
- NESubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords);
+ NESubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords, bool extend_parent = false);
/** Destructor: free the tensor's memory */
~NESubTensorHandle() = default;
/** Allow instances of this class to be move constructed */
diff --git a/arm_compute/graph2/frontend/Layers.h b/arm_compute/graph2/frontend/Layers.h
index 7ea23e0684..779b471b52 100644
--- a/arm_compute/graph2/frontend/Layers.h
+++ b/arm_compute/graph2/frontend/Layers.h
@@ -187,11 +187,10 @@ public:
NodeID create_layer(IStream &s) override
{
- ARM_COMPUTE_UNUSED(_num_groups);
NodeIdxPair input = { s.tail_node(), 0 };
NodeParams common_params = { "", s.hints().target_hint };
return GraphBuilder::add_convolution_node(s.graph(), common_params, input,
- Size2D(_conv_width, _conv_height), _ofm, _conv_info,
+ Size2D(_conv_width, _conv_height), _ofm, _conv_info, _num_groups,
s.hints().convolution_method_hint,
std::move(_weights), std::move(_bias));
}
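The frontend ConvolutionLayer already captured a group count but previously discarded it (note the removed ARM_COMPUTE_UNUSED). Assuming a constructor order matching the member usage in create_layer above (width, height, ofm, weights, bias, conv_info, then the group count), a grouped layer in a stream would look roughly like this; the .npy paths are illustrative, following the naming in examples/graph_alexnet.cpp:

```cpp
graph << ConvolutionLayer(
          5U, 5U, 256U,
          get_weights_accessor(data_path, "/cnn_data/alexnet_model/conv2_w.npy"),
          get_weights_accessor(data_path, "/cnn_data/alexnet_model/conv2_b.npy"),
          PadStrideInfo(1, 1, 2, 2), 2 /* num_groups */)
```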
diff --git a/arm_compute/graph2/mutators/GraphMutators.h b/arm_compute/graph2/mutators/GraphMutators.h
index b432e329e2..3275e32961 100644
--- a/arm_compute/graph2/mutators/GraphMutators.h
+++ b/arm_compute/graph2/mutators/GraphMutators.h
@@ -27,5 +27,6 @@
#include "arm_compute/graph2/mutators/DepthConcatSubTensorMutator.h"
#include "arm_compute/graph2/mutators/InPlaceOperationMutator.h"
#include "arm_compute/graph2/mutators/NodeFusionMutator.h"
+#include "arm_compute/graph2/mutators/SplitLayerSubTensorMutator.h"
#endif /* __ARM_COMPUTE_GRAPH2_GRAPH_MUTATORS_H__ */
diff --git a/arm_compute/graph2/mutators/SplitLayerSubTensorMutator.h b/arm_compute/graph2/mutators/SplitLayerSubTensorMutator.h
new file mode 100644
index 0000000000..82ee509a32
--- /dev/null
+++ b/arm_compute/graph2/mutators/SplitLayerSubTensorMutator.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_GRAPH2_SPLIT_LAYER_SUBTENSOR_MUTATOR_H__
+#define __ARM_COMPUTE_GRAPH2_SPLIT_LAYER_SUBTENSOR_MUTATOR_H__
+
+#include "arm_compute/graph2/IGraphMutator.h"
+
+namespace arm_compute
+{
+namespace graph2
+{
+/** Mutation pass to optimize split operations by using sub-tensors
+ *
+ * @warning This is compulsory to run in case Split layers are present in the model
+ **/
+class SplitLayerSubTensorMutator final : public IGraphMutator
+{
+public:
+ // Inherited methods overridden
+ virtual void mutate(Graph &g) override;
+ const char *name() override;
+};
+} // namespace graph2
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_GRAPH2_SPLIT_LAYER_SUBTENSOR_MUTATOR_H__ */
diff --git a/arm_compute/graph2/nodes/Nodes.h b/arm_compute/graph2/nodes/Nodes.h
index 8201361304..3786978661 100644
--- a/arm_compute/graph2/nodes/Nodes.h
+++ b/arm_compute/graph2/nodes/Nodes.h
@@ -39,5 +39,6 @@
#include "arm_compute/graph2/nodes/PoolingLayerNode.h"
#include "arm_compute/graph2/nodes/ReshapeLayerNode.h"
#include "arm_compute/graph2/nodes/SoftmaxLayerNode.h"
+#include "arm_compute/graph2/nodes/SplitLayerNode.h"
#endif /* __ARM_COMPUTE_GRAPH2_NODES_H__ */
diff --git a/arm_compute/graph2/nodes/NodesFwd.h b/arm_compute/graph2/nodes/NodesFwd.h
index 03ca65e056..08f2454cde 100644
--- a/arm_compute/graph2/nodes/NodesFwd.h
+++ b/arm_compute/graph2/nodes/NodesFwd.h
@@ -45,6 +45,7 @@ class OutputNode;
class PoolingLayerNode;
class ReshapeLayerNode;
class SoftmaxLayerNode;
+class SplitLayerNode;
} // namespace graph2
} // namespace arm_compute
#endif /* __ARM_COMPUTE_GRAPH2_NODES_FWD_H__ */
diff --git a/arm_compute/graph2/nodes/SplitLayerNode.h b/arm_compute/graph2/nodes/SplitLayerNode.h
new file mode 100644
index 0000000000..90e6134ac0
--- /dev/null
+++ b/arm_compute/graph2/nodes/SplitLayerNode.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_GRAPH2_SPLIT_LAYER_NODE_H__
+#define __ARM_COMPUTE_GRAPH2_SPLIT_LAYER_NODE_H__
+
+#include "arm_compute/graph2/INode.h"
+
+#include <tuple>
+
+namespace arm_compute
+{
+namespace graph2
+{
+/** Split Layer node */
+class SplitLayerNode final : public INode
+{
+public:
+ /** Default Constructor
+ *
+ * @param[in] num_splits Number of splits
+ * @param[in] axis (Optional) Axis to split on. Supported axis >= 2. Defaults to 0
+ */
+ SplitLayerNode(unsigned int num_splits, unsigned int axis = 0);
+ /** Computes split layer output shape
+ *
+ * @param[in] input_shape Shape of the input
+ * @param[in] num_splits Number of splits
+ * @param[in] axis Axis to perform the split on
+ * @param[in] idx Index of the split
+ *
+ * @return A pair with the shape of the split and the starting coordinates
+ */
+ static std::pair<TensorShape, Coordinates> compute_output_shape(TensorShape input_shape, unsigned int num_splits, unsigned int axis, unsigned int idx);
+ /** Number of splits accessor
+ *
+ * @return Number of splits
+ */
+ unsigned int num_splits() const;
+ /** Split axis accessor
+ *
+ * @return Split axis
+ */
+ unsigned int axis() const;
+
+ // Inherited overridden methods:
+ Status validate() override;
+ NodeType type() const override;
+ bool forward_descriptors() override;
+ TensorDescriptor configure_output(size_t idx) const override;
+ void accept(INodeVisitor &v) override;
+
+private:
+ unsigned int _num_splits;
+ unsigned int _axis;
+};
+} // namespace graph2
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_GRAPH2_SPLIT_LAYER_NODE_H__ */
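compute_output_shape is pure shape arithmetic and easy to mirror outside the library. A standalone sketch (not library code) splitting a 55x55x96 tensor into two parts along the channel axis:

```cpp
#include <array>
#include <cstdio>

int main()
{
    std::array<unsigned int, 3> input_shape{ { 55, 55, 96 } }; // W, H, C
    const unsigned int num_splits = 2;
    const unsigned int axis       = 2;
    const unsigned int split_size = input_shape[axis] / num_splits;

    for(unsigned int idx = 0; idx < num_splits; ++idx)
    {
        // Each split keeps the full shape except along `axis`.
        std::array<unsigned int, 3> output_shape = input_shape;
        output_shape[axis] = split_size;
        const unsigned int start = idx * split_size; // coordinate along the split axis

        std::printf("split %u: shape (%u, %u, %u), origin at coordinate %u on axis %u\n",
                    idx, output_shape[0], output_shape[1], output_shape[2], start, axis);
    }
    return 0;
}
```

Each split receives input_shape[axis] / num_splits elements starting at idx * split_size; the returned coordinates become the view origin that the sub-tensor mutator uses later.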
diff --git a/examples/graph_alexnet.cpp b/examples/graph_alexnet.cpp
index f887f97a12..6ba3ebc7ae 100644
--- a/examples/graph_alexnet.cpp
+++ b/examples/graph_alexnet.cpp
@@ -21,8 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "arm_compute/graph/Graph.h"
-#include "arm_compute/graph/Nodes.h"
+#include "arm_compute/graph2.h"
#include "support/ToolchainSupport.h"
#include "utils/GraphUtils.h"
#include "utils/Utils.h"
@@ -32,7 +31,7 @@
#include <memory>
using namespace arm_compute::utils;
-using namespace arm_compute::graph;
+using namespace arm_compute::graph2::frontend;
using namespace arm_compute::graph_utils;
/** Example demonstrating how to implement AlexNet's network using the Compute Library's graph API
@@ -54,13 +53,16 @@ public:
std::unique_ptr<IPreprocessor> preprocessor = arm_compute::support::cpp14::make_unique<CaffePreproccessor>(mean_rgb);
// Set target. 0 (NEON), 1 (OpenCL), 2 (OpenCL with Tuner). By default it is NEON
- const int int_target_hint = argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0;
- TargetHint target_hint = set_target_hint(int_target_hint);
+ const int target = argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0;
+ Target target_hint = set_target_hint2(target);
+ bool enable_tuning = (target == 2);
+ bool enable_memory_management = true;
- const bool is_gemm_convolution5x5 = Graph::gpu_target() == arm_compute::GPUTarget::MIDGARD || target_hint == TargetHint::NEON;
- const bool is_winograd_convolution3x3 = target_hint == TargetHint::OPENCL;
- ConvolutionMethodHint convolution_5x5_hint = is_gemm_convolution5x5 ? ConvolutionMethodHint::GEMM : ConvolutionMethodHint::DIRECT;
- ConvolutionMethodHint convolution_3x3_hint = is_winograd_convolution3x3 ? ConvolutionMethodHint::WINOGRAD : ConvolutionMethodHint::GEMM;
+ // TODO (geopin01) : Get GPU target somehow and set gemm also for midgard ?
+ const bool is_gemm_convolution5x5 = (target_hint == Target::NEON);
+ const bool is_winograd_convolution3x3 = target_hint == Target::CL;
+ ConvolutionMethod convolution_5x5_hint = is_gemm_convolution5x5 ? ConvolutionMethod::GEMM : ConvolutionMethod::DIRECT;
+ ConvolutionMethod convolution_3x3_hint = is_winograd_convolution3x3 ? ConvolutionMethod::WINOGRAD : ConvolutionMethod::GEMM;
// Parse arguments
if(argc < 2)
@@ -95,8 +97,8 @@ public:
}
graph << target_hint
- << Tensor(TensorInfo(TensorShape(227U, 227U, 3U, 1U), 1, DataType::F32),
- get_input_accessor(image, std::move(preprocessor)))
+ << InputLayer(TensorDescriptor(TensorShape(227U, 227U, 3U, 1U), DataType::F32),
+ get_input_accessor(image, std::move(preprocessor)))
// Layer 1
<< ConvolutionLayer(
11U, 11U, 96U,
@@ -158,10 +160,10 @@ public:
get_weights_accessor(data_path, "/cnn_data/alexnet_model/fc8_b.npy"))
// Softmax
<< SoftmaxLayer()
- << Tensor(get_output_accessor(label, 5));
+ << OutputLayer(get_output_accessor(label, 5));
- // In order to enable the OpenCL tuner, graph_init() has to be called only when all nodes have been instantiated
- graph.graph_init(int_target_hint == 2);
+ // Finalize graph
+ graph.finalize(target_hint, enable_tuning, enable_memory_management);
}
void do_run() override
{
@@ -170,7 +172,7 @@ public:
}
private:
- Graph graph{};
+ Stream graph{ 0, "AlexNet" };
};
/** Main program for AlexNet
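The example migrates from the old graph API to the graph2 frontend: Stream replaces Graph, InputLayer/OutputLayer replace the raw Tensor endpoints, and graph_init(bool) becomes an explicit finalize() step. A stripped-down skeleton of the new pattern (a sketch; image, label and preprocessor are placeholders borrowed from the example above):

```cpp
using namespace arm_compute::graph2::frontend;

Stream graph{ 0, "Example" };

graph << Target::NEON
      << InputLayer(TensorDescriptor(TensorShape(227U, 227U, 3U, 1U), DataType::F32),
                    get_input_accessor(image, std::move(preprocessor)))
      /* ... hidden layers ... */
      << SoftmaxLayer()
      << OutputLayer(get_output_accessor(label, 5));

// Tuning and memory management are now explicit flags instead of being
// inferred from the target index:
graph.finalize(Target::NEON, /* enable_tuning */ false, /* enable_memory_management */ true);
```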
diff --git a/src/graph2/GraphBuilder.cpp b/src/graph2/GraphBuilder.cpp
index aaf70c4e61..e6fc2afe21 100644
--- a/src/graph2/GraphBuilder.cpp
+++ b/src/graph2/GraphBuilder.cpp
@@ -46,6 +46,7 @@ Status set_node_params(Graph &g, NodeID nid, NodeParams &params)
return Status{};
}
+
Status set_accessor_on_node(Graph &g, NodeID nid, bool is_output, size_t idx, ITensorAccessorUPtr accessor)
{
INode *node = g.node(nid);
@@ -66,6 +67,55 @@ NodeID add_const_node_with_name(Graph &g, NodeParams params, const std::string &
set_node_params(g, nid, params);
return nid;
}
+
+template <typename NT, typename... Args>
+NodeID create_simple_single_input_output_node(Graph &g, NodeParams &params, NodeIdxPair input, Args &&... args)
+{
+ CHECK_NODEIDX_PAIR(input, g);
+
+ NodeID nid = g.add_node<NT>(std::forward<Args>(args)...);
+ g.add_connection(input.node_id, input.index, nid, 0);
+ set_node_params(g, nid, params);
+
+ return nid;
+}
+
+NodeID create_grouped_convolution(Graph &g, NodeParams &params, NodeIdxPair input, NodeID weights, NodeID bias,
+ PadStrideInfo conv_info, ConvolutionMethod method, unsigned int num_groups)
+{
+ bool has_bias = (bias != EmptyNodeID);
+
+ // Split input
+ NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, 2);
+
+ // Split weights
+ NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, 3);
+
+ // Split bias
+ NodeID bias_split = EmptyNodeID;
+ if(has_bias)
+ {
+ // Split bias
+ bias_split = GraphBuilder::add_split_node(g, params, { bias, 0 }, num_groups, 0);
+ }
+
+ std::vector<NodeIdxPair> convolution_outputs;
+ for(unsigned int i = 0; i < num_groups; ++i)
+ {
+ NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
+ g.add_connection(input_split, i, conv_nid, 0);
+ g.add_connection(weights_split, i, conv_nid, 1);
+ if(has_bias)
+ {
+ g.add_connection(bias_split, i, conv_nid, 2);
+ }
+ set_node_params(g, conv_nid, params);
+ convolution_outputs.push_back({ conv_nid, 0 });
+ }
+
+ // Depth concatenate output
+ return GraphBuilder::add_depth_concatenate_node(g, params, convolution_outputs);
+}
} // namespace
NodeID GraphBuilder::add_const_node(Graph &g, NodeParams params, TensorDescriptor desc, ITensorAccessorUPtr accessor)
@@ -98,13 +148,7 @@ NodeID GraphBuilder::add_output_node(Graph &g, NodeParams params, NodeIdxPair in
NodeID GraphBuilder::add_activation_node(Graph &g, NodeParams params, NodeIdxPair input, ActivationLayerInfo act_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<ActivationLayerNode>(act_info);
- g.add_connection(input.node_id, input.index, nid, 0);
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<ActivationLayerNode>(g, params, input, act_info);
}
NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, NodeIdxPair input, float epsilon,
@@ -161,7 +205,7 @@ NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, N
NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPair input,
Size2D kernel_spatial_extend, unsigned int depth, PadStrideInfo conv_info,
- ConvolutionMethod method,
+ unsigned int num_groups, ConvolutionMethod method,
ITensorAccessorUPtr weights_accessor, ITensorAccessorUPtr bias_accessor)
{
CHECK_NODEIDX_PAIR(input, g);
@@ -175,7 +219,7 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa
// Create weights node
TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z(), depth);
+ w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z() / num_groups, depth);
NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
// Create bias nodes
@@ -187,17 +231,24 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa
b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
}
- // Create convolution node and connect
- NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
- g.add_connection(input.node_id, input.index, conv_nid, 0);
- g.add_connection(w_nid, 0, conv_nid, 1);
- if(has_bias)
+ if(num_groups == 1)
{
- g.add_connection(b_nid, 0, conv_nid, 2);
+ // Create convolution node and connect
+ NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
+ g.add_connection(input.node_id, input.index, conv_nid, 0);
+ g.add_connection(w_nid, 0, conv_nid, 1);
+ if(has_bias)
+ {
+ g.add_connection(b_nid, 0, conv_nid, 2);
+ }
+ set_node_params(g, conv_nid, params);
+
+ return conv_nid;
+ }
+ else
+ {
+ return create_grouped_convolution(g, params, input, w_nid, b_nid, conv_info, method, num_groups);
}
- set_node_params(g, conv_nid, params);
-
- return conv_nid;
}
NodeID GraphBuilder::add_depth_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs)
@@ -273,14 +324,7 @@ NodeID GraphBuilder::add_elementwise_node(Graph &g, NodeParams params, NodeIdxPa
NodeID GraphBuilder::add_flatten_node(Graph &g, NodeParams params, NodeIdxPair input)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<FlattenLayerNode>();
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<FlattenLayerNode>(g, params, input);
}
NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_outputs,
@@ -324,50 +368,27 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node
NodeID GraphBuilder::add_normalization_node(Graph &g, NodeParams params, NodeIdxPair input, NormalizationLayerInfo norm_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<NormalizationLayerNode>(norm_info);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<NormalizationLayerNode>(g, params, input, norm_info);
}
NodeID GraphBuilder::add_pooling_node(Graph &g, NodeParams params, NodeIdxPair input, PoolingLayerInfo pool_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<PoolingLayerNode>(pool_info);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<PoolingLayerNode>(g, params, input, pool_info);
}
NodeID GraphBuilder::add_reshape_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<ReshapeLayerNode>(shape);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<ReshapeLayerNode>(g, params, input, shape);
}
NodeID GraphBuilder::add_softmax_node(Graph &g, NodeParams params, NodeIdxPair input, float beta)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<SoftmaxLayerNode>(beta);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
+ return create_simple_single_input_output_node<SoftmaxLayerNode>(g, params, input, beta);
+}
- return nid;
+NodeID GraphBuilder::add_split_node(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_splits, unsigned int axis)
+{
+ return create_simple_single_input_output_node<SplitLayerNode>(g, params, input, num_splits, axis);
}
} // namespace graph2
} // namespace arm_compute
\ No newline at end of file
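For num_groups > 1 the builder expands the single logical convolution into a small sub-graph via create_grouped_convolution. Sketched for num_groups == 2 on a WxHx96 input with 256 kernels (shapes illustrative):

```
input (WxHx96)           weights (5x5x48x256)      bias (256)
      |                         |                      |
 SplitLayer(2, axis=2)    SplitLayer(2, axis=3)   SplitLayer(2, axis=0)
      |                         |                      |
      +------- one ConvolutionLayerNode per group -----+
                          |
                 DepthConcatenateLayer
                          |
                output (W' x H' x 256)
```

Note that the weights descriptor is built with w_desc.shape.z() / num_groups (see the w_desc change above), so each group's kernel slice matches its input channel slice.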
diff --git a/src/graph2/Utils.cpp b/src/graph2/Utils.cpp
index a518c80da8..3ff400bf61 100644
--- a/src/graph2/Utils.cpp
+++ b/src/graph2/Utils.cpp
@@ -77,6 +77,7 @@ PassManager create_default_pass_manager()
pm.append(support::cpp14::make_unique<InPlaceOperationMutator>());
pm.append(support::cpp14::make_unique<NodeFusionMutator>());
+ pm.append(support::cpp14::make_unique<SplitLayerSubTensorMutator>());
pm.append(support::cpp14::make_unique<DepthConcatSubTensorMutator>());
return pm;
diff --git a/src/graph2/backends/CL/CLDeviceBackend.cpp b/src/graph2/backends/CL/CLDeviceBackend.cpp
index 28e053415b..6d2d4f9b1a 100644
--- a/src/graph2/backends/CL/CLDeviceBackend.cpp
+++ b/src/graph2/backends/CL/CLDeviceBackend.cpp
@@ -127,14 +127,14 @@ std::unique_ptr<ITensorHandle> CLDeviceBackend::create_tensor(const Tensor &tens
return std::move(backend_tensor_handle);
}
-std::unique_ptr<ITensorHandle> CLDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords)
+std::unique_ptr<ITensorHandle> CLDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent)
{
if(parent == nullptr)
{
return nullptr;
}
- return support::cpp14::make_unique<CLSubTensorHandle>(parent, shape, coords);
+ return support::cpp14::make_unique<CLSubTensorHandle>(parent, shape, coords, extend_parent);
}
std::unique_ptr<arm_compute::IFunction> CLDeviceBackend::configure_node(INode &node, GraphContext &ctx)
diff --git a/src/graph2/backends/CL/CLSubTensorHandle.cpp b/src/graph2/backends/CL/CLSubTensorHandle.cpp
index 2954652d71..a001d57832 100644
--- a/src/graph2/backends/CL/CLSubTensorHandle.cpp
+++ b/src/graph2/backends/CL/CLSubTensorHandle.cpp
@@ -31,12 +31,12 @@ namespace graph2
{
namespace backends
{
-CLSubTensorHandle::CLSubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords)
+CLSubTensorHandle::CLSubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords, bool extend_parent)
: _sub_tensor()
{
ARM_COMPUTE_ERROR_ON(!parent_handle);
auto parent_tensor = arm_compute::utils::cast::polymorphic_downcast<ICLTensor *>(&parent_handle->tensor());
- _sub_tensor = arm_compute::CLSubTensor(parent_tensor, shape, coords);
+ _sub_tensor = arm_compute::CLSubTensor(parent_tensor, shape, coords, extend_parent);
}
void CLSubTensorHandle::allocate()
diff --git a/src/graph2/backends/NEON/NEDeviceBackend.cpp b/src/graph2/backends/NEON/NEDeviceBackend.cpp
index 5569abf41b..9010c5d802 100644
--- a/src/graph2/backends/NEON/NEDeviceBackend.cpp
+++ b/src/graph2/backends/NEON/NEDeviceBackend.cpp
@@ -86,14 +86,14 @@ std::unique_ptr<ITensorHandle> NEDeviceBackend::create_tensor(const Tensor &tens
return std::move(backend_tensor_handle);
}
-std::unique_ptr<ITensorHandle> NEDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords)
+std::unique_ptr<ITensorHandle> NEDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent)
{
if(parent == nullptr)
{
return nullptr;
}
- return support::cpp14::make_unique<NESubTensorHandle>(parent, shape, coords);
+ return support::cpp14::make_unique<NESubTensorHandle>(parent, shape, coords, extend_parent);
}
std::unique_ptr<arm_compute::IFunction> NEDeviceBackend::configure_node(INode &node, GraphContext &ctx)
diff --git a/src/graph2/backends/NEON/NESubTensorHandle.cpp b/src/graph2/backends/NEON/NESubTensorHandle.cpp
index 9b3c9b18d6..491cf8259c 100644
--- a/src/graph2/backends/NEON/NESubTensorHandle.cpp
+++ b/src/graph2/backends/NEON/NESubTensorHandle.cpp
@@ -29,11 +29,11 @@ namespace graph2
{
namespace backends
{
-NESubTensorHandle::NESubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords)
+NESubTensorHandle::NESubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords, bool extend_parent)
: _sub_tensor()
{
ARM_COMPUTE_ERROR_ON(!parent_handle);
- _sub_tensor = arm_compute::SubTensor(&parent_handle->tensor(), shape, coords);
+ _sub_tensor = arm_compute::SubTensor(&parent_handle->tensor(), shape, coords, extend_parent);
}
void NESubTensorHandle::allocate()
diff --git a/src/graph2/mutators/DepthConcatSubTensorMutator.cpp b/src/graph2/mutators/DepthConcatSubTensorMutator.cpp
index cc8de6bb1b..ea3743bf21 100644
--- a/src/graph2/mutators/DepthConcatSubTensorMutator.cpp
+++ b/src/graph2/mutators/DepthConcatSubTensorMutator.cpp
@@ -70,7 +70,7 @@ void DepthConcatSubTensorMutator::mutate(Graph &g)
const auto input_shape = input_tensor->desc().shape;
auto backend = backends::BackendRegistry::get().find_backend(input_tensor->desc().target);
- auto handle = backend->create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth));
+ auto handle = backend->create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth), false);
input_tensor->set_handle(std::move(handle));
depth += input_shape.z();
diff --git a/src/graph2/mutators/SplitLayerSubTensorMutator.cpp b/src/graph2/mutators/SplitLayerSubTensorMutator.cpp
new file mode 100644
index 0000000000..33494ba6bc
--- /dev/null
+++ b/src/graph2/mutators/SplitLayerSubTensorMutator.cpp
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph2/mutators/SplitLayerSubTensorMutator.h"
+
+#include "arm_compute/graph2/Graph.h"
+#include "arm_compute/graph2/Logger.h"
+#include "arm_compute/graph2/backends/BackendRegistry.h"
+#include "arm_compute/graph2/nodes/SplitLayerNode.h"
+
+#include "arm_compute/core/utils/misc/Cast.h"
+#include "arm_compute/core/utils/misc/Iterable.h"
+
+namespace arm_compute
+{
+namespace graph2
+{
+const char *SplitLayerSubTensorMutator::name()
+{
+ return "SplitLayerSubTensorMutator";
+}
+
+void SplitLayerSubTensorMutator::mutate(Graph &g)
+{
+ // Should be in reverse order of execution
+ for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
+ {
+ if(node && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
+ {
+ // Get output tensor
+ Tensor *input_tensor = node->input(0);
+
+ // Check that all tensor have the same target and are valid
+ bool is_valid = std::all_of(node->outputs().cbegin(), node->outputs().cend(),
+ [&](const TensorID & tid)
+ {
+ return (g.tensor(tid) != nullptr) && (g.tensor(tid)->desc().target == input_tensor->desc().target);
+ });
+
+ // Create subtensors
+ if(is_valid && backends::BackendRegistry::get().find_backend(input_tensor->desc().target) != nullptr)
+ {
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
+ << node->id() << " and name : " << node->name() << std::endl);
+
+ auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node.get());
+
+ const unsigned int axis = split_node->axis();
+ const unsigned int num_splits = split_node->num_splits();
+ const bool extend_parent = (axis < 2);
+
+ // Create sub-tensor handles
+ for(unsigned int i = 0; i < node->outputs().size(); ++i)
+ {
+ Tensor *output_tensor = node->output(i);
+ const TensorShape output_shape = output_tensor->desc().shape;
+ Coordinates coords;
+ std::tie(std::ignore, coords) = SplitLayerNode::compute_output_shape(input_tensor->desc().shape, num_splits, axis, i);
+
+ backends::IDeviceBackend *backend = backends::BackendRegistry::get().find_backend(output_tensor->desc().target);
+ std::unique_ptr<ITensorHandle> handle = backend->create_subtensor(input_tensor->handle(), output_shape, coords, extend_parent);
+ output_tensor->set_handle(std::move(handle));
+ }
+ }
+ }
+ }
+}
+} // namespace graph2
+} // namespace arm_compute
diff --git a/src/graph2/nodes/SplitLayerNode.cpp b/src/graph2/nodes/SplitLayerNode.cpp
new file mode 100644
index 0000000000..c34a7ff176
--- /dev/null
+++ b/src/graph2/nodes/SplitLayerNode.cpp
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph2/nodes/SplitLayerNode.h"
+
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/graph2/Graph.h"
+#include "arm_compute/graph2/INodeVisitor.h"
+
+namespace arm_compute
+{
+namespace graph2
+{
+SplitLayerNode::SplitLayerNode(unsigned int num_splits, unsigned int axis)
+ : _num_splits(num_splits), _axis(axis)
+{
+ _input_edges.resize(1, EmptyEdgeID);
+ _outputs.resize(num_splits, NullTensorID);
+}
+
+unsigned int SplitLayerNode::num_splits() const
+{
+ return _num_splits;
+}
+
+unsigned int SplitLayerNode::axis() const
+{
+ return _axis;
+}
+
+std::pair<TensorShape, Coordinates> SplitLayerNode::compute_output_shape(TensorShape input_shape, unsigned int num_splits, unsigned int axis, unsigned int idx)
+{
+ ARM_COMPUTE_ERROR_ON(axis >= input_shape.num_dimensions());
+ ARM_COMPUTE_ERROR_ON_MSG(input_shape[axis] % num_splits, "Split should be exact");
+
+ const unsigned int split_size = input_shape[axis] / num_splits;
+
+ TensorShape output_shape = input_shape;
+ output_shape.set(axis, split_size);
+
+ Coordinates coords;
+ coords.set(axis, idx * split_size);
+
+ return std::make_pair(output_shape, coords);
+}
+
+bool SplitLayerNode::forward_descriptors()
+{
+ if(input_id(0) != NullTensorID)
+ {
+ for(unsigned int i = 0; i < _outputs.size(); ++i)
+ {
+ if(output_id(i) != NullTensorID)
+ {
+ Tensor *dst_i = output(i);
+ ARM_COMPUTE_ERROR_ON(dst_i == nullptr);
+ dst_i->desc() = configure_output(i);
+ }
+ }
+ return true;
+ }
+ return false;
+}
+
+TensorDescriptor SplitLayerNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_UNUSED(idx);
+ ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
+
+ const Tensor *src = input(0);
+ ARM_COMPUTE_ERROR_ON(src == nullptr);
+
+ TensorShape output_shape;
+
+ TensorDescriptor output_info = src->desc();
+ std::tie(output_shape, std::ignore) = compute_output_shape(src->desc().shape, _num_splits, _axis, idx);
+ output_info.shape = output_shape;
+
+ return output_info;
+}
+
+Status SplitLayerNode::validate()
+{
+ return Status{};
+}
+
+NodeType SplitLayerNode::type() const
+{
+ return NodeType::SplitLayer;
+}
+
+void SplitLayerNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+} // namespace graph2
} // namespace arm_compute
\ No newline at end of file
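compute_output_shape asserts that the split is exact (input_shape[axis] % num_splits == 0, "Split should be exact"). A minimal standalone check one might run before constructing the node (a sketch, not part of the patch):

```cpp
#include <cstdio>

// Returns true if dim_size elements can be split evenly into num_splits parts.
bool split_is_exact(unsigned int dim_size, unsigned int num_splits)
{
    return (num_splits != 0) && (dim_size % num_splits == 0);
}

int main()
{
    std::printf("96 into 2: %s\n", split_is_exact(96, 2) ? "ok" : "would assert");
    std::printf("96 into 5: %s\n", split_is_exact(96, 5) ? "ok" : "would assert");
    return 0;
}
```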