diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-04-27 19:07:19 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:51:17 +0000 |
commit | cac13b1cfd593889271f8e2191be2039b8d88f36 (patch) | |
tree | d1c5196877d7fbd5dcfbb9f9003faf6035f82a33 /src/graph/nodes/SplitLayerNode.cpp | |
parent | ad0c7388f6261989a268ffb2d042f2bd80736e3f (diff) | |
download | ComputeLibrary-cac13b1cfd593889271f8e2191be2039b8d88f36.tar.gz |
COMPMID-1097: Port mobilenet to NHWC
Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/graph/nodes/SplitLayerNode.cpp')
-rw-r--r-- | src/graph/nodes/SplitLayerNode.cpp | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/src/graph/nodes/SplitLayerNode.cpp b/src/graph/nodes/SplitLayerNode.cpp index c8fb43c2a1..5d46c9dcc9 100644 --- a/src/graph/nodes/SplitLayerNode.cpp +++ b/src/graph/nodes/SplitLayerNode.cpp @@ -48,26 +48,25 @@ unsigned int SplitLayerNode::axis() const return _axis; } -std::pair<TensorShape, Coordinates> SplitLayerNode::compute_output_shape(TensorShape input_shape, unsigned int num_splits, unsigned int axis, unsigned int idx) +std::pair<TensorDescriptor, Coordinates> SplitLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, + unsigned int num_splits, unsigned int axis, unsigned int idx) { - ARM_COMPUTE_ERROR_ON(axis >= input_shape.num_dimensions()); - ARM_COMPUTE_ERROR_ON_MSG(input_shape[axis] % num_splits, "Split should be exact"); + const unsigned int split_size = input_descriptor.shape[axis] / num_splits; - const unsigned int split_size = input_shape[axis] / num_splits; - - TensorShape output_shape = input_shape; - output_shape.set(axis, split_size); + TensorDescriptor output_descriptor = input_descriptor; + output_descriptor.shape.set(axis, split_size); Coordinates coords; coords.set(axis, idx * split_size); - return std::make_pair(output_shape, coords); + return std::make_pair(output_descriptor, coords); } bool SplitLayerNode::forward_descriptors() { if(input_id(0) != NullTensorID) { + validate(); for(unsigned int i = 0; i < _outputs.size(); ++i) { if(output_id(i) != NullTensorID) @@ -90,17 +89,19 @@ TensorDescriptor SplitLayerNode::configure_output(size_t idx) const const Tensor *src = input(0); ARM_COMPUTE_ERROR_ON(src == nullptr); - TensorShape output_shape; - - TensorDescriptor output_info = src->desc(); - std::tie(output_shape, std::ignore) = compute_output_shape(src->desc().shape, _num_splits, _axis, idx); - output_info.shape = output_shape; + TensorDescriptor output_info; + std::tie(output_info, std::ignore) = compute_output_descriptor(src->desc(), _num_splits, _axis, idx); return output_info; } -Status SplitLayerNode::validate() +Status SplitLayerNode::validate() const { + const Tensor *src = input(0); + ARM_COMPUTE_RETURN_ERROR_ON(src == nullptr); + ARM_COMPUTE_RETURN_ERROR_ON(_axis >= src->desc().shape.num_dimensions()); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->desc().shape[_axis] % _num_splits, "Split should be exact"); + return Status{}; } |