From e55b40a4d0cc5a82b8f0fd9ffec203ded9f3c63d Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 13 Sep 2018 17:20:04 +0100 Subject: COMPMID-1581: Collapse windows Change-Id: Iec56c9a96d9736a63f13b65efa33311950f20661 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/148572 Reviewed-by: Anthony Barbier Tested-by: bsgcomp --- arm_compute/core/utils/misc/ShapeCalculator.h | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) (limited to 'arm_compute/core/utils/misc') diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index e88fd8d75e..6d8e15b8b2 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -176,13 +176,21 @@ inline TensorShape compute_col2im_shape(const ITensorInfo &input, const Size2D & ARM_COMPUTE_ERROR_ON(input.tensor_shape()[1] != (convolved_dims.area())); ARM_COMPUTE_ERROR_ON((num_groups > 1) && input.tensor_shape()[2] != num_groups); - TensorShape col2im_shape{ input.tensor_shape() }; - col2im_shape.set(0, convolved_dims.width); - col2im_shape.set(1, convolved_dims.height); - col2im_shape.set(2, input.tensor_shape()[0] * num_groups); + const DataLayout data_layout = input.data_layout(); + const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - const unsigned int batch_idx = (batch_size_on_z && num_groups == 1) ? 2 : 3; - col2im_shape.set(3, input.tensor_shape()[batch_idx]); + TensorShape col2im_shape{ input.tensor_shape() }; + // If batches start on 3rd dimension shift dimensions right by 1 to retain upper tensor shape, + // as first three will be override by H,W,C data + if(batch_size_on_z && num_groups == 1) + { + col2im_shape.shift_right(1); + } + col2im_shape.set(width_idx, convolved_dims.width); + col2im_shape.set(height_idx, convolved_dims.height); + col2im_shape.set(channel_idx, input.tensor_shape()[0] * num_groups); return col2im_shape; } -- cgit v1.2.1