diff options
Diffstat (limited to 'src/graph/nodes/SplitLayerNode.cpp')
-rw-r--r-- | src/graph/nodes/SplitLayerNode.cpp | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/src/graph/nodes/SplitLayerNode.cpp b/src/graph/nodes/SplitLayerNode.cpp index 31931c3a79..dfb6624f80 100644 --- a/src/graph/nodes/SplitLayerNode.cpp +++ b/src/graph/nodes/SplitLayerNode.cpp @@ -49,8 +49,8 @@ unsigned int SplitLayerNode::axis() const return _axis; } -std::pair<TensorDescriptor, Coordinates> SplitLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, - unsigned int num_splits, int axis, unsigned int idx) +std::pair<TensorDescriptor, Coordinates> SplitLayerNode::compute_output_descriptor( + const TensorDescriptor &input_descriptor, unsigned int num_splits, int axis, unsigned int idx) { // Handle negative axis, negative index is used to specify axis from the end (e.g. -1 for the last axis). int num_dimension = static_cast<int32_t>(input_descriptor.shape.num_dimensions()); @@ -58,7 +58,7 @@ std::pair<TensorDescriptor, Coordinates> SplitLayerNode::compute_output_descript Coordinates coords; TensorDescriptor output_descriptor = input_descriptor; int split_size = input_descriptor.shape[tmp_axis] / num_splits; - if(_size_splits.empty()) + if (_size_splits.empty()) { output_descriptor.shape.set(tmp_axis, split_size); coords.set(tmp_axis, idx * split_size); @@ -66,15 +66,15 @@ std::pair<TensorDescriptor, Coordinates> SplitLayerNode::compute_output_descript else { int split_size = _size_splits[idx]; - if(split_size == -1) + if (split_size == -1) { split_size = input_descriptor.shape[tmp_axis]; - for(unsigned int i = 0; i < _size_splits.size() - 1; ++i) + for (unsigned int i = 0; i < _size_splits.size() - 1; ++i) split_size -= _size_splits[i]; } output_descriptor.shape.set(tmp_axis, split_size); int coord_value = 0; - for(unsigned int i = 0; i < idx; ++i) + for (unsigned int i = 0; i < idx; ++i) coord_value += _size_splits[i]; coords.set(tmp_axis, coord_value); } @@ -84,12 +84,12 @@ std::pair<TensorDescriptor, Coordinates> SplitLayerNode::compute_output_descript bool SplitLayerNode::forward_descriptors() { - if(input_id(0) != NullTensorID) + if (input_id(0) != NullTensorID) { validate(); - for(unsigned int i = 0; i < _outputs.size(); ++i) + for (unsigned int i = 0; i < _outputs.size(); ++i) { - if(output_id(i) != NullTensorID) + if (output_id(i) != NullTensorID) { Tensor *dst_i = output(i); ARM_COMPUTE_ERROR_ON(dst_i == nullptr); @@ -117,10 +117,10 @@ TensorDescriptor SplitLayerNode::configure_output(size_t idx) const int tmp_axis = wrap_around(_axis, num_dimension); int split_size = (_size_splits.empty()) ? (input_descriptor.shape[tmp_axis] / _num_splits) : _size_splits[idx]; - if(split_size == -1) + if (split_size == -1) { split_size = input_descriptor.shape[tmp_axis]; - for(unsigned int i = 0; i < _size_splits.size() - 1; ++i) + for (unsigned int i = 0; i < _size_splits.size() - 1; ++i) split_size -= _size_splits[i]; } output_descriptor.shape.set(tmp_axis, split_size); @@ -138,7 +138,7 @@ Status SplitLayerNode::validate() const // Handle negative axis, negative index is used to specify axis from the end (e.g. -1 for the last axis). int tmp_axis = wrap_around(_axis, num_dimension); - if(_size_splits.empty()) + if (_size_splits.empty()) { ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->desc().shape[tmp_axis] % _num_splits, "Split should be exact"); } @@ -156,4 +156,4 @@ void SplitLayerNode::accept(INodeVisitor &v) v.visit(*this); } } // namespace graph -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |