diff options
Diffstat (limited to 'src/graph/nodes/ReorgLayerNode.cpp')
-rw-r--r-- | src/graph/nodes/ReorgLayerNode.cpp | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/src/graph/nodes/ReorgLayerNode.cpp b/src/graph/nodes/ReorgLayerNode.cpp index 6b83f6b90c..21ad451c3e 100644 --- a/src/graph/nodes/ReorgLayerNode.cpp +++ b/src/graph/nodes/ReorgLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,10 +53,11 @@ TensorDescriptor ReorgLayerNode::compute_output_descriptor(const TensorDescripto ARM_COMPUTE_ERROR_ON_MSG((input_width % stride != 0), "The width of the input tensor must be a multiple of stride"); ARM_COMPUTE_ERROR_ON_MSG((input_height % stride != 0), "The height of the input tensor must be a multiple of stride"); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), input_width / stride); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), input_height / stride); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::CHANNEL), input_channel * stride * stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), input_width / stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), input_height / stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::CHANNEL), input_channel * stride * stride); return output_descriptor; } |