From 9e4824c909b14dbaf7106e9527b0ffa22ef09bdc Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 12 Apr 2019 13:15:58 +0100 Subject: COMPMID-2111: ConcatenateLayer API should accept an index instead of an enum Alters the concatenate layer to be layout agnostic and accept an index as thec concatenation axis instead of an typed layout dependent enumeration. Change-Id: I0eaaf919f66a1ba1b09bbfb47c171fc1d4045530 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/994 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins --- src/graph/GraphBuilder.cpp | 34 ++++++++++++---------- src/graph/Utils.cpp | 12 ++++---- src/graph/mutators/DepthConcatSubTensorMutator.cpp | 4 +-- src/graph/mutators/GroupedConvolutionMutator.cpp | 6 ++-- src/graph/nodes/ConcatenateLayerNode.cpp | 2 +- src/graph/nodes/ConvolutionLayerNode.cpp | 9 +++--- src/graph/nodes/DeconvolutionLayerNode.cpp | 9 +++--- src/graph/nodes/DepthwiseConvolutionLayerNode.cpp | 7 +++-- .../FusedConvolutionBatchNormalizationNode.cpp | 7 +++-- src/graph/nodes/PoolingLayerNode.cpp | 7 +++-- src/graph/nodes/ReorgLayerNode.cpp | 9 +++--- src/graph/nodes/ResizeLayerNode.cpp | 7 +++-- src/graph/nodes/UpsampleLayerNode.cpp | 7 +++-- 13 files changed, 66 insertions(+), 54 deletions(-) (limited to 'src/graph') diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index b96a242acf..9f8dd69922 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -221,14 +221,15 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa // Get input tensor descriptor const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); + const DataLayout input_data_layout = input_tensor_desc.layout; // Create weights node TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL), get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL) / num_groups); - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth); + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::BATCHES), depth); if(!weights_quant_info.empty()) { w_desc.quant_info = weights_quant_info; @@ -275,14 +276,15 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx // Get input tensor descriptor const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); + const DataLayout input_data_layout = input_tensor_desc.layout; // Create weights node TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL), get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL)); - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth); + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::BATCHES), depth); NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor)); @@ -328,12 +330,13 @@ NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params, // Get input tensor descriptor const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); + const DataLayout input_data_layout = input_tensor_desc.layout; // Create weights node TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); - w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); + w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL), get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL) * depth_multiplier); if(!quant_info.empty()) { @@ -595,13 +598,14 @@ NodeID GraphBuilder::add_scale_layer(Graph &g, const NodeParams ¶ms, NodeIdx // Get input tensor descriptor const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); + const DataLayout input_data_layout = input_tensor_desc.layout; // Create mul node TensorDescriptor mul_desc = input_tensor_desc; - const size_t C = input_tensor_desc.shape[get_dimension_idx(mul_desc, DataLayoutDimension::CHANNEL)]; - mul_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), 1); - mul_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), 1); - mul_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), C); + const size_t C = input_tensor_desc.shape[get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL)]; + mul_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::WIDTH), 1); + mul_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::HEIGHT), 1); + mul_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL), C); NodeID mul_const_nid = add_const_node_with_name(g, params, "Mul", mul_desc, std::move(mul_accessor)); NodeIdxPair mul_const_nidxp = { mul_const_nid, 0 }; diff --git a/src/graph/Utils.cpp b/src/graph/Utils.cpp index 71a6fc582b..4c34dd85a5 100644 --- a/src/graph/Utils.cpp +++ b/src/graph/Utils.cpp @@ -119,12 +119,12 @@ void setup_requested_backend_context(GraphContext &ctx, Target target) size_t get_dimension_size(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension) { ARM_COMPUTE_ERROR_ON_MSG(descriptor.layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); - return descriptor.shape[get_dimension_idx(descriptor, data_layout_dimension)]; + return descriptor.shape[get_dimension_idx(descriptor.layout, data_layout_dimension)]; } -size_t get_dimension_idx(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension) +size_t get_dimension_idx(DataLayout data_layout, const DataLayoutDimension data_layout_dimension) { - ARM_COMPUTE_ERROR_ON_MSG(descriptor.layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); + ARM_COMPUTE_ERROR_ON_MSG(data_layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); /* Return the index based on the data layout * [N C H W] @@ -134,13 +134,13 @@ size_t get_dimension_idx(const TensorDescriptor &descriptor, const DataLayoutDim switch(data_layout_dimension) { case DataLayoutDimension::CHANNEL: - return (descriptor.layout == DataLayout::NCHW) ? 2 : 0; + return (data_layout == DataLayout::NCHW) ? 2 : 0; break; case DataLayoutDimension::HEIGHT: - return (descriptor.layout == DataLayout::NCHW) ? 1 : 2; + return (data_layout == DataLayout::NCHW) ? 1 : 2; break; case DataLayoutDimension::WIDTH: - return (descriptor.layout == DataLayout::NCHW) ? 0 : 1; + return (data_layout == DataLayout::NCHW) ? 0 : 1; break; case DataLayoutDimension::BATCHES: return 3; diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp index 0e0a26b886..7994541b78 100644 --- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp +++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp @@ -62,9 +62,9 @@ void DepthConcatSubTensorMutator::mutate(Graph &g) // Get output tensor auto output_tensor = node->output(0); - // Check concatenation axis (Sub-tensor optimization is support for concatenation axis >=2) + // Check concatenation axis (Sub-tensor optimization is supported for concatenation axis >=2) auto *concat_node = arm_compute::utils::cast::polymorphic_downcast(node); - if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc(), concat_node->concatenation_axis()) < 2) + if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc().layout, concat_node->concatenation_axis()) < 2) { continue; } diff --git a/src/graph/mutators/GroupedConvolutionMutator.cpp b/src/graph/mutators/GroupedConvolutionMutator.cpp index d69d2cd7d0..3d53f49218 100644 --- a/src/graph/mutators/GroupedConvolutionMutator.cpp +++ b/src/graph/mutators/GroupedConvolutionMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -47,12 +47,12 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams ¶ms, NodeIdxPai // Split input const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); - const unsigned int input_idx = get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL); + const unsigned int input_idx = get_dimension_idx(input_tensor_desc.layout, DataLayoutDimension::CHANNEL); NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx); // Split weights const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]); - const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc, DataLayoutDimension::BATCHES); + const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc.layout, DataLayoutDimension::BATCHES); NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx); // Split bias diff --git a/src/graph/nodes/ConcatenateLayerNode.cpp b/src/graph/nodes/ConcatenateLayerNode.cpp index bbdc4dc029..48da8b6e9e 100644 --- a/src/graph/nodes/ConcatenateLayerNode.cpp +++ b/src/graph/nodes/ConcatenateLayerNode.cpp @@ -67,7 +67,7 @@ TensorDescriptor ConcatenateLayerNode::compute_output_descriptor(const std::vect ARM_COMPUTE_ERROR_ON(input_descriptors.size() == 0); TensorDescriptor output_descriptor = input_descriptors[0]; - const int axis_idx = get_dimension_idx(output_descriptor, axis); + const int axis_idx = get_dimension_idx(output_descriptor.layout, axis); // Extract shapes std::vector shapes; diff --git a/src/graph/nodes/ConvolutionLayerNode.cpp b/src/graph/nodes/ConvolutionLayerNode.cpp index 15c7ff68f8..1c8dcaecfc 100644 --- a/src/graph/nodes/ConvolutionLayerNode.cpp +++ b/src/graph/nodes/ConvolutionLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -97,10 +97,11 @@ TensorDescriptor ConvolutionLayerNode::compute_output_descriptor(const TensorDes std::tie(output_width, output_height) = scaled_dimensions(input_width, input_height, kernel_width, kernel_height, info); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), output_width); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), output_height); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::CHANNEL), weights_descriptor.shape[3]); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), output_width); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), output_height); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::CHANNEL), weights_descriptor.shape[3]); return output_descriptor; } diff --git a/src/graph/nodes/DeconvolutionLayerNode.cpp b/src/graph/nodes/DeconvolutionLayerNode.cpp index e7ccffd04f..b1a6db7ccc 100644 --- a/src/graph/nodes/DeconvolutionLayerNode.cpp +++ b/src/graph/nodes/DeconvolutionLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -66,10 +66,11 @@ TensorDescriptor DeconvolutionLayerNode::compute_output_descriptor(const TensorD info.pad().first, info.pad().second, info.stride().first, info.stride().second); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), output_width); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), output_height); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::CHANNEL), weights_descriptor.shape[3]); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), output_width); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), output_height); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::CHANNEL), weights_descriptor.shape[3]); return output_descriptor; } diff --git a/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp b/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp index 935902d3fd..cdd9e7b601 100644 --- a/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp +++ b/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp @@ -86,10 +86,11 @@ TensorDescriptor DepthwiseConvolutionLayerNode::compute_output_descriptor(const std::tie(output_width, output_height) = scaled_dimensions(input_width, input_height, kernel_width, kernel_height, info); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), output_width); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), output_height); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::CHANNEL), input_channels * depth_multiplier); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), output_width); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), output_height); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::CHANNEL), input_channels * depth_multiplier); return output_descriptor; } diff --git a/src/graph/nodes/FusedConvolutionBatchNormalizationNode.cpp b/src/graph/nodes/FusedConvolutionBatchNormalizationNode.cpp index 27a348fa69..c304a6c605 100644 --- a/src/graph/nodes/FusedConvolutionBatchNormalizationNode.cpp +++ b/src/graph/nodes/FusedConvolutionBatchNormalizationNode.cpp @@ -102,10 +102,11 @@ TensorDescriptor FusedConvolutionBatchNormalizationNode::compute_output_descript std::tie(output_width, output_height) = scaled_dimensions(input_width, input_height, kernel_width, kernel_height, info); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), output_width); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), output_height); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::CHANNEL), weights_descriptor.shape[3]); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), output_width); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), output_height); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::CHANNEL), weights_descriptor.shape[3]); return output_descriptor; } diff --git a/src/graph/nodes/PoolingLayerNode.cpp b/src/graph/nodes/PoolingLayerNode.cpp index 26c145ae31..48b93c9158 100644 --- a/src/graph/nodes/PoolingLayerNode.cpp +++ b/src/graph/nodes/PoolingLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -57,9 +57,10 @@ TensorDescriptor PoolingLayerNode::compute_output_descriptor(const TensorDescrip std::tie(pooled_width, pooled_height) = scaled_dimensions(input_width, input_height, pool_size_x, pool_size_y, info.pad_stride_info()); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), pooled_width); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), pooled_height); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), pooled_width); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), pooled_height); return output_descriptor; } diff --git a/src/graph/nodes/ReorgLayerNode.cpp b/src/graph/nodes/ReorgLayerNode.cpp index 6b83f6b90c..21ad451c3e 100644 --- a/src/graph/nodes/ReorgLayerNode.cpp +++ b/src/graph/nodes/ReorgLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,10 +53,11 @@ TensorDescriptor ReorgLayerNode::compute_output_descriptor(const TensorDescripto ARM_COMPUTE_ERROR_ON_MSG((input_width % stride != 0), "The width of the input tensor must be a multiple of stride"); ARM_COMPUTE_ERROR_ON_MSG((input_height % stride != 0), "The height of the input tensor must be a multiple of stride"); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), input_width / stride); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), input_height / stride); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::CHANNEL), input_channel * stride * stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), input_width / stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), input_height / stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::CHANNEL), input_channel * stride * stride); return output_descriptor; } diff --git a/src/graph/nodes/ResizeLayerNode.cpp b/src/graph/nodes/ResizeLayerNode.cpp index a6aa7bfe5c..a399229013 100644 --- a/src/graph/nodes/ResizeLayerNode.cpp +++ b/src/graph/nodes/ResizeLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -68,9 +68,10 @@ TensorDescriptor ResizeLayerNode::configure_output(size_t idx) const const Tensor *src = input(0); ARM_COMPUTE_ERROR_ON(src == nullptr); + const DataLayout data_layout = src->desc().layout; TensorDescriptor output_desc = src->desc(); - size_t width_idx = get_dimension_idx(output_desc, DataLayoutDimension::WIDTH); - size_t height_idx = get_dimension_idx(output_desc, DataLayoutDimension::HEIGHT); + size_t width_idx = get_dimension_idx(data_layout, DataLayoutDimension::WIDTH); + size_t height_idx = get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT); output_desc.shape.set(width_idx, static_cast(output_desc.shape[width_idx] * _scale_width)); output_desc.shape.set(height_idx, static_cast(output_desc.shape[height_idx] * _scale_height)); diff --git a/src/graph/nodes/UpsampleLayerNode.cpp b/src/graph/nodes/UpsampleLayerNode.cpp index bdd39e8ebd..88af122a59 100644 --- a/src/graph/nodes/UpsampleLayerNode.cpp +++ b/src/graph/nodes/UpsampleLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -54,9 +54,10 @@ TensorDescriptor UpsampleLayerNode::compute_output_descriptor(const TensorDescri const unsigned int input_width = get_dimension_size(input_descriptor, DataLayoutDimension::WIDTH); const unsigned int input_height = get_dimension_size(input_descriptor, DataLayoutDimension::HEIGHT); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), input_width * info.x()); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), input_height * info.y()); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), input_width * info.x()); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), input_height * info.y()); return output_descriptor; } -- cgit v1.2.1