From cac13b1cfd593889271f8e2191be2039b8d88f36 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 27 Apr 2018 19:07:19 +0100 Subject: COMPMID-1097: Port mobilenet to NHWC Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/graph/nodes/SplitLayerNode.cpp | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) (limited to 'src/graph/nodes/SplitLayerNode.cpp') diff --git a/src/graph/nodes/SplitLayerNode.cpp b/src/graph/nodes/SplitLayerNode.cpp index c8fb43c2a1..5d46c9dcc9 100644 --- a/src/graph/nodes/SplitLayerNode.cpp +++ b/src/graph/nodes/SplitLayerNode.cpp @@ -48,26 +48,25 @@ unsigned int SplitLayerNode::axis() const return _axis; } -std::pair SplitLayerNode::compute_output_shape(TensorShape input_shape, unsigned int num_splits, unsigned int axis, unsigned int idx) +std::pair SplitLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, + unsigned int num_splits, unsigned int axis, unsigned int idx) { - ARM_COMPUTE_ERROR_ON(axis >= input_shape.num_dimensions()); - ARM_COMPUTE_ERROR_ON_MSG(input_shape[axis] % num_splits, "Split should be exact"); + const unsigned int split_size = input_descriptor.shape[axis] / num_splits; - const unsigned int split_size = input_shape[axis] / num_splits; - - TensorShape output_shape = input_shape; - output_shape.set(axis, split_size); + TensorDescriptor output_descriptor = input_descriptor; + output_descriptor.shape.set(axis, split_size); Coordinates coords; coords.set(axis, idx * split_size); - return std::make_pair(output_shape, coords); + return std::make_pair(output_descriptor, coords); } bool SplitLayerNode::forward_descriptors() { if(input_id(0) != NullTensorID) { + validate(); for(unsigned int i = 0; i < _outputs.size(); ++i) { if(output_id(i) != NullTensorID) @@ -90,17 +89,19 @@ TensorDescriptor SplitLayerNode::configure_output(size_t idx) const const Tensor *src = input(0); ARM_COMPUTE_ERROR_ON(src == nullptr); - TensorShape output_shape; - - TensorDescriptor output_info = src->desc(); - std::tie(output_shape, std::ignore) = compute_output_shape(src->desc().shape, _num_splits, _axis, idx); - output_info.shape = output_shape; + TensorDescriptor output_info; + std::tie(output_info, std::ignore) = compute_output_descriptor(src->desc(), _num_splits, _axis, idx); return output_info; } -Status SplitLayerNode::validate() +Status SplitLayerNode::validate() const { + const Tensor *src = input(0); + ARM_COMPUTE_RETURN_ERROR_ON(src == nullptr); + ARM_COMPUTE_RETURN_ERROR_ON(_axis >= src->desc().shape.num_dimensions()); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->desc().shape[_axis] % _num_splits, "Split should be exact"); + return Status{}; } -- cgit v1.2.1