author    Georgios Pinitas <georgios.pinitas@arm.com>  2018-03-08 16:01:29 +0000
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:49:16 +0000
commit    ee33ea5a6e1aa0faac1cc8b5a269bd4f89854821 (patch)
tree      0baf159ae4a61d07cc765ad6bb1a2fb42c403081 /src
parent    e86a09fe4c5aa9037787e13ee55cba2b049d5ea5 (diff)
download  ComputeLibrary-ee33ea5a6e1aa0faac1cc8b5a269bd4f89854821.tar.gz
COMPMID-996: Add support for grouped convolution.
Change-Id: I279e29ce20b3dde57445264dc11491f127b44d70
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/124429
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r--  src/graph2/GraphBuilder.cpp                          135
-rw-r--r--  src/graph2/Utils.cpp                                   1
-rw-r--r--  src/graph2/backends/CL/CLDeviceBackend.cpp             4
-rw-r--r--  src/graph2/backends/CL/CLSubTensorHandle.cpp           4
-rw-r--r--  src/graph2/backends/NEON/NEDeviceBackend.cpp           4
-rw-r--r--  src/graph2/backends/NEON/NESubTensorHandle.cpp         4
-rw-r--r--  src/graph2/mutators/DepthConcatSubTensorMutator.cpp    2
-rw-r--r--  src/graph2/mutators/SplitLayerSubTensorMutator.cpp    89
-rw-r--r--  src/graph2/nodes/SplitLayerNode.cpp                  117
9 files changed, 294 insertions, 66 deletions
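
The decomposition that create_grouped_convolution() below implements can be summarised with a minimal standalone sketch (not part of this patch; all sizes are illustrative): the input is split along its channel axis (axis 2), the weights along their kernel axis (axis 3), the optional bias along axis 0, one ConvolutionLayerNode is added per group, and the per-group outputs are depth-concatenated.

    #include <cassert>
    #include <cstdio>

    int main()
    {
        // Illustrative sizes, not taken from the patch.
        const unsigned int in_channels = 192; // input feature maps (z dimension of the input)
        const unsigned int depth       = 256; // total number of kernels / output feature maps
        const unsigned int num_groups  = 2;

        // Mirrors the "Split should be exact" check in SplitLayerNode::compute_output_shape().
        assert(in_channels % num_groups == 0);
        assert(depth % num_groups == 0);

        const unsigned int group_in  = in_channels / num_groups; // channels fed to each group's convolution
        const unsigned int group_out = depth / num_groups;       // kernels owned by each group's convolution

        for(unsigned int i = 0; i < num_groups; ++i)
        {
            // add_split_node(..., num_groups, 2): input slice of group_in channels at z-offset i * group_in
            // add_split_node(..., num_groups, 3): weights slice of group_out kernels at offset i * group_out
            // add_split_node(..., num_groups, 0): bias slice of group_out values at offset i * group_out
            std::printf("group %u: %u input channels (z-offset %u) -> %u output channels (concat z-offset %u)\n",
                        i, group_in, i * group_in, group_out, i * group_out);
        }
        return 0;
    }
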
diff --git a/src/graph2/GraphBuilder.cpp b/src/graph2/GraphBuilder.cpp
index aaf70c4e61..e6fc2afe21 100644
--- a/src/graph2/GraphBuilder.cpp
+++ b/src/graph2/GraphBuilder.cpp
@@ -46,6 +46,7 @@ Status set_node_params(Graph &g, NodeID nid, NodeParams &params)
return Status{};
}
+
Status set_accessor_on_node(Graph &g, NodeID nid, bool is_output, size_t idx, ITensorAccessorUPtr accessor)
{
INode *node = g.node(nid);
@@ -66,6 +67,55 @@ NodeID add_const_node_with_name(Graph &g, NodeParams params, const std::string &
set_node_params(g, nid, params);
return nid;
}
+
+template <typename NT, typename... Args>
+NodeID create_simple_single_input_output_node(Graph &g, NodeParams &params, NodeIdxPair input, Args &&... args)
+{
+ CHECK_NODEIDX_PAIR(input, g);
+
+ NodeID nid = g.add_node<NT>(std::forward<Args>(args)...);
+ g.add_connection(input.node_id, input.index, nid, 0);
+ set_node_params(g, nid, params);
+
+ return nid;
+}
+
+NodeID create_grouped_convolution(Graph &g, NodeParams &params, NodeIdxPair input, NodeID weights, NodeID bias,
+ PadStrideInfo conv_info, ConvolutionMethod method, unsigned int num_groups)
+{
+ bool has_bias = (bias != EmptyNodeID);
+
+ // Split input
+ NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, 2);
+
+ // Split weights
+ NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, 3);
+
+ // Split bias
+ NodeID bias_split = EmptyNodeID;
+ if(has_bias)
+ {
+ // Split bias
+ bias_split = GraphBuilder::add_split_node(g, params, { bias, 0 }, num_groups, 0);
+ }
+
+ std::vector<NodeIdxPair> convolution_outputs;
+ for(unsigned int i = 0; i < num_groups; ++i)
+ {
+ NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
+ g.add_connection(input_split, i, conv_nid, 0);
+ g.add_connection(weights_split, i, conv_nid, 1);
+ if(has_bias)
+ {
+ g.add_connection(bias_split, i, conv_nid, 2);
+ }
+ set_node_params(g, conv_nid, params);
+ convolution_outputs.push_back({ conv_nid, 0 });
+ }
+
+ // Depth concatenate output
+ return GraphBuilder::add_depth_concatenate_node(g, params, convolution_outputs);
+}
} // namespace
NodeID GraphBuilder::add_const_node(Graph &g, NodeParams params, TensorDescriptor desc, ITensorAccessorUPtr accessor)
@@ -98,13 +148,7 @@ NodeID GraphBuilder::add_output_node(Graph &g, NodeParams params, NodeIdxPair in
NodeID GraphBuilder::add_activation_node(Graph &g, NodeParams params, NodeIdxPair input, ActivationLayerInfo act_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<ActivationLayerNode>(act_info);
- g.add_connection(input.node_id, input.index, nid, 0);
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<ActivationLayerNode>(g, params, input, act_info);
}
NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, NodeIdxPair input, float epsilon,
@@ -161,7 +205,7 @@ NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, N
NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPair input,
Size2D kernel_spatial_extend, unsigned int depth, PadStrideInfo conv_info,
- ConvolutionMethod method,
+ unsigned int num_groups, ConvolutionMethod method,
ITensorAccessorUPtr weights_accessor, ITensorAccessorUPtr bias_accessor)
{
CHECK_NODEIDX_PAIR(input, g);
@@ -175,7 +219,7 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa
// Create weights node
TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z(), depth);
+ w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z() / num_groups, depth);
NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
// Create bias nodes
@@ -187,17 +231,24 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa
b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
}
- // Create convolution node and connect
- NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
- g.add_connection(input.node_id, input.index, conv_nid, 0);
- g.add_connection(w_nid, 0, conv_nid, 1);
- if(has_bias)
+ if(num_groups == 1)
{
- g.add_connection(b_nid, 0, conv_nid, 2);
+ // Create convolution node and connect
+ NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
+ g.add_connection(input.node_id, input.index, conv_nid, 0);
+ g.add_connection(w_nid, 0, conv_nid, 1);
+ if(has_bias)
+ {
+ g.add_connection(b_nid, 0, conv_nid, 2);
+ }
+ set_node_params(g, conv_nid, params);
+
+ return conv_nid;
+ }
+ else
+ {
+ return create_grouped_convolution(g, params, input, w_nid, b_nid, conv_info, method, num_groups);
}
- set_node_params(g, conv_nid, params);
-
- return conv_nid;
}
NodeID GraphBuilder::add_depth_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs)
@@ -273,14 +324,7 @@ NodeID GraphBuilder::add_elementwise_node(Graph &g, NodeParams params, NodeIdxPa
NodeID GraphBuilder::add_flatten_node(Graph &g, NodeParams params, NodeIdxPair input)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<FlattenLayerNode>();
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<FlattenLayerNode>(g, params, input);
}
NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_outputs,
@@ -324,50 +368,27 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node
NodeID GraphBuilder::add_normalization_node(Graph &g, NodeParams params, NodeIdxPair input, NormalizationLayerInfo norm_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<NormalizationLayerNode>(norm_info);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<NormalizationLayerNode>(g, params, input, norm_info);
}
NodeID GraphBuilder::add_pooling_node(Graph &g, NodeParams params, NodeIdxPair input, PoolingLayerInfo pool_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<PoolingLayerNode>(pool_info);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<PoolingLayerNode>(g, params, input, pool_info);
}
NodeID GraphBuilder::add_reshape_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<ReshapeLayerNode>(shape);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<ReshapeLayerNode>(g, params, input, shape);
}
NodeID GraphBuilder::add_softmax_node(Graph &g, NodeParams params, NodeIdxPair input, float beta)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<SoftmaxLayerNode>(beta);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
+ return create_simple_single_input_output_node<SoftmaxLayerNode>(g, params, input, beta);
+}
- return nid;
+NodeID GraphBuilder::add_split_node(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_splits, unsigned int axis)
+{
+ return create_simple_single_input_output_node<SplitLayerNode>(g, params, input, num_splits, axis);
}
} // namespace graph2
} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph2/Utils.cpp b/src/graph2/Utils.cpp
index a518c80da8..3ff400bf61 100644
--- a/src/graph2/Utils.cpp
+++ b/src/graph2/Utils.cpp
@@ -77,6 +77,7 @@ PassManager create_default_pass_manager()
pm.append(support::cpp14::make_unique<InPlaceOperationMutator>());
pm.append(support::cpp14::make_unique<NodeFusionMutator>());
+ pm.append(support::cpp14::make_unique<SplitLayerSubTensorMutator>());
pm.append(support::cpp14::make_unique<DepthConcatSubTensorMutator>());
return pm;
diff --git a/src/graph2/backends/CL/CLDeviceBackend.cpp b/src/graph2/backends/CL/CLDeviceBackend.cpp
index 28e053415b..6d2d4f9b1a 100644
--- a/src/graph2/backends/CL/CLDeviceBackend.cpp
+++ b/src/graph2/backends/CL/CLDeviceBackend.cpp
@@ -127,14 +127,14 @@ std::unique_ptr<ITensorHandle> CLDeviceBackend::create_tensor(const Tensor &tens
return std::move(backend_tensor_handle);
}
-std::unique_ptr<ITensorHandle> CLDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords)
+std::unique_ptr<ITensorHandle> CLDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent)
{
if(parent == nullptr)
{
return nullptr;
}
- return support::cpp14::make_unique<CLSubTensorHandle>(parent, shape, coords);
+ return support::cpp14::make_unique<CLSubTensorHandle>(parent, shape, coords, extend_parent);
}
std::unique_ptr<arm_compute::IFunction> CLDeviceBackend::configure_node(INode &node, GraphContext &ctx)
diff --git a/src/graph2/backends/CL/CLSubTensorHandle.cpp b/src/graph2/backends/CL/CLSubTensorHandle.cpp
index 2954652d71..a001d57832 100644
--- a/src/graph2/backends/CL/CLSubTensorHandle.cpp
+++ b/src/graph2/backends/CL/CLSubTensorHandle.cpp
@@ -31,12 +31,12 @@ namespace graph2
{
namespace backends
{
-CLSubTensorHandle::CLSubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords)
+CLSubTensorHandle::CLSubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords, bool extend_parent)
: _sub_tensor()
{
ARM_COMPUTE_ERROR_ON(!parent_handle);
auto parent_tensor = arm_compute::utils::cast::polymorphic_downcast<ICLTensor *>(&parent_handle->tensor());
- _sub_tensor = arm_compute::CLSubTensor(parent_tensor, shape, coords);
+ _sub_tensor = arm_compute::CLSubTensor(parent_tensor, shape, coords, extend_parent);
}
void CLSubTensorHandle::allocate()
diff --git a/src/graph2/backends/NEON/NEDeviceBackend.cpp b/src/graph2/backends/NEON/NEDeviceBackend.cpp
index 5569abf41b..9010c5d802 100644
--- a/src/graph2/backends/NEON/NEDeviceBackend.cpp
+++ b/src/graph2/backends/NEON/NEDeviceBackend.cpp
@@ -86,14 +86,14 @@ std::unique_ptr<ITensorHandle> NEDeviceBackend::create_tensor(const Tensor &tens
return std::move(backend_tensor_handle);
}
-std::unique_ptr<ITensorHandle> NEDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords)
+std::unique_ptr<ITensorHandle> NEDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent)
{
if(parent == nullptr)
{
return nullptr;
}
- return support::cpp14::make_unique<NESubTensorHandle>(parent, shape, coords);
+ return support::cpp14::make_unique<NESubTensorHandle>(parent, shape, coords, extend_parent);
}
std::unique_ptr<arm_compute::IFunction> NEDeviceBackend::configure_node(INode &node, GraphContext &ctx)
diff --git a/src/graph2/backends/NEON/NESubTensorHandle.cpp b/src/graph2/backends/NEON/NESubTensorHandle.cpp
index 9b3c9b18d6..491cf8259c 100644
--- a/src/graph2/backends/NEON/NESubTensorHandle.cpp
+++ b/src/graph2/backends/NEON/NESubTensorHandle.cpp
@@ -29,11 +29,11 @@ namespace graph2
{
namespace backends
{
-NESubTensorHandle::NESubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords)
+NESubTensorHandle::NESubTensorHandle(ITensorHandle *parent_handle, const TensorShape &shape, const Coordinates &coords, bool extend_parent)
: _sub_tensor()
{
ARM_COMPUTE_ERROR_ON(!parent_handle);
- _sub_tensor = arm_compute::SubTensor(&parent_handle->tensor(), shape, coords);
+ _sub_tensor = arm_compute::SubTensor(&parent_handle->tensor(), shape, coords, extend_parent);
}
void NESubTensorHandle::allocate()
diff --git a/src/graph2/mutators/DepthConcatSubTensorMutator.cpp b/src/graph2/mutators/DepthConcatSubTensorMutator.cpp
index cc8de6bb1b..ea3743bf21 100644
--- a/src/graph2/mutators/DepthConcatSubTensorMutator.cpp
+++ b/src/graph2/mutators/DepthConcatSubTensorMutator.cpp
@@ -70,7 +70,7 @@ void DepthConcatSubTensorMutator::mutate(Graph &g)
const auto input_shape = input_tensor->desc().shape;
auto backend = backends::BackendRegistry::get().find_backend(input_tensor->desc().target);
- auto handle = backend->create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth));
+ auto handle = backend->create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth), false);
input_tensor->set_handle(std::move(handle));
depth += input_shape.z();
diff --git a/src/graph2/mutators/SplitLayerSubTensorMutator.cpp b/src/graph2/mutators/SplitLayerSubTensorMutator.cpp
new file mode 100644
index 0000000000..33494ba6bc
--- /dev/null
+++ b/src/graph2/mutators/SplitLayerSubTensorMutator.cpp
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph2/mutators/SplitLayerSubTensorMutator.h"
+
+#include "arm_compute/graph2/Graph.h"
+#include "arm_compute/graph2/Logger.h"
+#include "arm_compute/graph2/backends/BackendRegistry.h"
+#include "arm_compute/graph2/nodes/SplitLayerNode.h"
+
+#include "arm_compute/core/utils/misc/Cast.h"
+#include "arm_compute/core/utils/misc/Iterable.h"
+
+namespace arm_compute
+{
+namespace graph2
+{
+const char *SplitLayerSubTensorMutator::name()
+{
+ return "SplitLayerSubTensorMutator";
+}
+
+void SplitLayerSubTensorMutator::mutate(Graph &g)
+{
+ // Should be in reverse order of execution
+ for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
+ {
+ if(node && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
+ {
+ // Get input tensor
+ Tensor *input_tensor = node->input(0);
+
+ // Check that all tensors have the same target and are valid
+ bool is_valid = std::all_of(node->outputs().cbegin(), node->outputs().cend(),
+ [&](const TensorID & tid)
+ {
+ return (g.tensor(tid) != nullptr) && (g.tensor(tid)->desc().target == input_tensor->desc().target);
+ });
+
+ // Create subtensors
+ if(is_valid && backends::BackendRegistry::get().find_backend(input_tensor->desc().target) != nullptr)
+ {
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
+ << node->id() << " and name : " << node->name() << std::endl);
+
+ auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node.get());
+
+ const unsigned int axis = split_node->axis();
+ const unsigned int num_splits = split_node->num_splits();
+ const bool extend_parent = (axis < 2);
+
+ // Create sub-tensor handles
+ for(unsigned int i = 0; i < node->outputs().size(); ++i)
+ {
+ Tensor *output_tensor = node->output(i);
+ const TensorShape output_shape = output_tensor->desc().shape;
+ Coordinates coords;
+ std::tie(std::ignore, coords) = SplitLayerNode::compute_output_shape(input_tensor->desc().shape, num_splits, axis, i);
+
+ backends::IDeviceBackend *backend = backends::BackendRegistry::get().find_backend(output_tensor->desc().target);
+ std::unique_ptr<ITensorHandle> handle = backend->create_subtensor(input_tensor->handle(), output_shape, coords, extend_parent);
+ output_tensor->set_handle(std::move(handle));
+ }
+ }
+ }
+ }
+}
+} // namespace graph2
+} // namespace arm_compute
diff --git a/src/graph2/nodes/SplitLayerNode.cpp b/src/graph2/nodes/SplitLayerNode.cpp
new file mode 100644
index 0000000000..c34a7ff176
--- /dev/null
+++ b/src/graph2/nodes/SplitLayerNode.cpp
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph2/nodes/SplitLayerNode.h"
+
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/graph2/Graph.h"
+#include "arm_compute/graph2/INodeVisitor.h"
+
+namespace arm_compute
+{
+namespace graph2
+{
+SplitLayerNode::SplitLayerNode(unsigned int num_splits, unsigned int axis)
+ : _num_splits(num_splits), _axis(axis)
+{
+ _input_edges.resize(1, EmptyEdgeID);
+ _outputs.resize(num_splits, NullTensorID);
+}
+
+unsigned int SplitLayerNode::num_splits() const
+{
+ return _num_splits;
+}
+
+unsigned int SplitLayerNode::axis() const
+{
+ return _axis;
+}
+
+std::pair<TensorShape, Coordinates> SplitLayerNode::compute_output_shape(TensorShape input_shape, unsigned int num_splits, unsigned int axis, unsigned int idx)
+{
+ ARM_COMPUTE_ERROR_ON(axis >= input_shape.num_dimensions());
+ ARM_COMPUTE_ERROR_ON_MSG(input_shape[axis] % num_splits, "Split should be exact");
+
+ const unsigned int split_size = input_shape[axis] / num_splits;
+
+ TensorShape output_shape = input_shape;
+ output_shape.set(axis, split_size);
+
+ Coordinates coords;
+ coords.set(axis, idx * split_size);
+
+ return std::make_pair(output_shape, coords);
+}
+
+bool SplitLayerNode::forward_descriptors()
+{
+ if(input_id(0) != NullTensorID)
+ {
+ for(unsigned int i = 0; i < _outputs.size(); ++i)
+ {
+ if(output_id(i) != NullTensorID)
+ {
+ Tensor *dst_i = output(i);
+ ARM_COMPUTE_ERROR_ON(dst_i == nullptr);
+ dst_i->desc() = configure_output(i);
+ }
+ }
+ return true;
+ }
+ return false;
+}
+
+TensorDescriptor SplitLayerNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_UNUSED(idx);
+ ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
+
+ const Tensor *src = input(0);
+ ARM_COMPUTE_ERROR_ON(src == nullptr);
+
+ TensorShape output_shape;
+
+ TensorDescriptor output_info = src->desc();
+ std::tie(output_shape, std::ignore) = compute_output_shape(src->desc().shape, _num_splits, _axis, idx);
+ output_info.shape = output_shape;
+
+ return output_info;
+}
+
+Status SplitLayerNode::validate()
+{
+ return Status{};
+}
+
+NodeType SplitLayerNode::type() const
+{
+ return NodeType::SplitLayer;
+}
+
+void SplitLayerNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+} // namespace graph2
} // namespace arm_compute
\ No newline at end of file
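
For reference, a minimal standalone check (not part of this patch; shape and split count are illustrative) of the arithmetic in SplitLayerNode::compute_output_shape(): each output keeps the input shape except along the split axis, which shrinks to input[axis] / num_splits, and output idx starts at idx * split_size along that axis; these are the coordinates that SplitLayerSubTensorMutator passes to create_subtensor().

    #include <cassert>
    #include <cstdio>

    int main()
    {
        const unsigned int shape[3]   = { 28, 28, 64 }; // W, H, C of the tensor being split (illustrative)
        const unsigned int axis       = 2;              // channel axis, as used on the grouped convolution path
        const unsigned int num_splits = 4;

        assert(shape[axis] % num_splits == 0); // "Split should be exact"
        const unsigned int split_size = shape[axis] / num_splits;

        for(unsigned int idx = 0; idx < num_splits; ++idx)
        {
            // Expected output: shape [28, 28, 16] with z start coordinates 0, 16, 32 and 48.
            std::printf("output %u: shape [%u, %u, %u], starts at z = %u\n",
                        idx, shape[0], shape[1], split_size, idx * split_size);
        }
        return 0;
    }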