From 9e4824c909b14dbaf7106e9527b0ffa22ef09bdc Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 12 Apr 2019 13:15:58 +0100 Subject: COMPMID-2111: ConcatenateLayer API should accept an index instead of an enum Alters the concatenate layer to be layout agnostic and accept an index as thec concatenation axis instead of an typed layout dependent enumeration. Change-Id: I0eaaf919f66a1ba1b09bbfb47c171fc1d4045530 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/994 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins --- src/graph/nodes/ReorgLayerNode.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'src/graph/nodes/ReorgLayerNode.cpp') diff --git a/src/graph/nodes/ReorgLayerNode.cpp b/src/graph/nodes/ReorgLayerNode.cpp index 6b83f6b90c..21ad451c3e 100644 --- a/src/graph/nodes/ReorgLayerNode.cpp +++ b/src/graph/nodes/ReorgLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,10 +53,11 @@ TensorDescriptor ReorgLayerNode::compute_output_descriptor(const TensorDescripto ARM_COMPUTE_ERROR_ON_MSG((input_width % stride != 0), "The width of the input tensor must be a multiple of stride"); ARM_COMPUTE_ERROR_ON_MSG((input_height % stride != 0), "The height of the input tensor must be a multiple of stride"); + const DataLayout data_layout = input_descriptor.layout; TensorDescriptor output_descriptor = input_descriptor; - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), input_width / stride); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), input_height / stride); - output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::CHANNEL), input_channel * stride * stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::WIDTH), input_width / stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT), input_height / stride); + output_descriptor.shape.set(get_dimension_idx(data_layout, DataLayoutDimension::CHANNEL), input_channel * stride * stride); return output_descriptor; } -- cgit v1.2.1