aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes/SplitLayerNode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/nodes/SplitLayerNode.cpp')
-rw-r--r--src/graph/nodes/SplitLayerNode.cpp81
1 files changed, 61 insertions, 20 deletions
diff --git a/src/graph/nodes/SplitLayerNode.cpp b/src/graph/nodes/SplitLayerNode.cpp
index 5d46c9dcc9..dfb6624f80 100644
--- a/src/graph/nodes/SplitLayerNode.cpp
+++ b/src/graph/nodes/SplitLayerNode.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "arm_compute/graph/nodes/SplitLayerNode.h"
+#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
@@ -31,8 +32,8 @@ namespace arm_compute
{
namespace graph
{
-SplitLayerNode::SplitLayerNode(unsigned int num_splits, unsigned int axis)
- : _num_splits(num_splits), _axis(axis)
+SplitLayerNode::SplitLayerNode(unsigned int num_splits, int axis, std::vector<int> size_splits)
+ : _num_splits(num_splits), _axis(axis), _size_splits(size_splits)
{
_input_edges.resize(1, EmptyEdgeID);
_outputs.resize(num_splits, NullTensorID);
@@ -48,28 +49,47 @@ unsigned int SplitLayerNode::axis() const
return _axis;
}
-std::pair<TensorDescriptor, Coordinates> SplitLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
- unsigned int num_splits, unsigned int axis, unsigned int idx)
+std::pair<TensorDescriptor, Coordinates> SplitLayerNode::compute_output_descriptor(
+ const TensorDescriptor &input_descriptor, unsigned int num_splits, int axis, unsigned int idx)
{
- const unsigned int split_size = input_descriptor.shape[axis] / num_splits;
-
+ // Handle negative axis, negative index is used to specify axis from the end (e.g. -1 for the last axis).
+ int num_dimension = static_cast<int32_t>(input_descriptor.shape.num_dimensions());
+ int tmp_axis = wrap_around(axis, num_dimension);
+ Coordinates coords;
TensorDescriptor output_descriptor = input_descriptor;
- output_descriptor.shape.set(axis, split_size);
-
- Coordinates coords;
- coords.set(axis, idx * split_size);
+ int split_size = input_descriptor.shape[tmp_axis] / num_splits;
+ if (_size_splits.empty())
+ {
+ output_descriptor.shape.set(tmp_axis, split_size);
+ coords.set(tmp_axis, idx * split_size);
+ }
+ else
+ {
+ int split_size = _size_splits[idx];
+ if (split_size == -1)
+ {
+ split_size = input_descriptor.shape[tmp_axis];
+ for (unsigned int i = 0; i < _size_splits.size() - 1; ++i)
+ split_size -= _size_splits[i];
+ }
+ output_descriptor.shape.set(tmp_axis, split_size);
+ int coord_value = 0;
+ for (unsigned int i = 0; i < idx; ++i)
+ coord_value += _size_splits[i];
+ coords.set(tmp_axis, coord_value);
+ }
return std::make_pair(output_descriptor, coords);
}
bool SplitLayerNode::forward_descriptors()
{
- if(input_id(0) != NullTensorID)
+ if (input_id(0) != NullTensorID)
{
validate();
- for(unsigned int i = 0; i < _outputs.size(); ++i)
+ for (unsigned int i = 0; i < _outputs.size(); ++i)
{
- if(output_id(i) != NullTensorID)
+ if (output_id(i) != NullTensorID)
{
Tensor *dst_i = output(i);
ARM_COMPUTE_ERROR_ON(dst_i == nullptr);
@@ -89,18 +109,39 @@ TensorDescriptor SplitLayerNode::configure_output(size_t idx) const
const Tensor *src = input(0);
ARM_COMPUTE_ERROR_ON(src == nullptr);
- TensorDescriptor output_info;
- std::tie(output_info, std::ignore) = compute_output_descriptor(src->desc(), _num_splits, _axis, idx);
+ TensorDescriptor input_descriptor = src->desc();
+ TensorDescriptor output_descriptor = input_descriptor;
- return output_info;
+ // Handle negative axis, negative index is used to specify axis from the end (e.g. -1 for the last axis).
+ int num_dimension = static_cast<int32_t>(src->desc().shape.num_dimensions());
+ int tmp_axis = wrap_around(_axis, num_dimension);
+
+ int split_size = (_size_splits.empty()) ? (input_descriptor.shape[tmp_axis] / _num_splits) : _size_splits[idx];
+ if (split_size == -1)
+ {
+ split_size = input_descriptor.shape[tmp_axis];
+ for (unsigned int i = 0; i < _size_splits.size() - 1; ++i)
+ split_size -= _size_splits[i];
+ }
+ output_descriptor.shape.set(tmp_axis, split_size);
+
+ return output_descriptor;
}
Status SplitLayerNode::validate() const
{
const Tensor *src = input(0);
ARM_COMPUTE_RETURN_ERROR_ON(src == nullptr);
- ARM_COMPUTE_RETURN_ERROR_ON(_axis >= src->desc().shape.num_dimensions());
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->desc().shape[_axis] % _num_splits, "Split should be exact");
+ int num_dimension = static_cast<int32_t>(src->desc().shape.num_dimensions());
+ ARM_COMPUTE_RETURN_ERROR_ON(_axis < (-num_dimension) || _axis >= num_dimension);
+
+ // Handle negative axis, negative index is used to specify axis from the end (e.g. -1 for the last axis).
+ int tmp_axis = wrap_around(_axis, num_dimension);
+
+ if (_size_splits.empty())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->desc().shape[tmp_axis] % _num_splits, "Split should be exact");
+ }
return Status{};
}
@@ -115,4 +156,4 @@ void SplitLayerNode::accept(INodeVisitor &v)
v.visit(*this);
}
} // namespace graph
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute