diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index e88fd8d75e..6d8e15b8b2 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -176,13 +176,21 @@ inline TensorShape compute_col2im_shape(const ITensorInfo &input, const Size2D & ARM_COMPUTE_ERROR_ON(input.tensor_shape()[1] != (convolved_dims.area())); ARM_COMPUTE_ERROR_ON((num_groups > 1) && input.tensor_shape()[2] != num_groups); - TensorShape col2im_shape{ input.tensor_shape() }; - col2im_shape.set(0, convolved_dims.width); - col2im_shape.set(1, convolved_dims.height); - col2im_shape.set(2, input.tensor_shape()[0] * num_groups); + const DataLayout data_layout = input.data_layout(); + const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - const unsigned int batch_idx = (batch_size_on_z && num_groups == 1) ? 2 : 3; - col2im_shape.set(3, input.tensor_shape()[batch_idx]); + TensorShape col2im_shape{ input.tensor_shape() }; + // If batches start on 3rd dimension shift dimensions right by 1 to retain upper tensor shape, + // as first three will be override by H,W,C data + if(batch_size_on_z && num_groups == 1) + { + col2im_shape.shift_right(1); + } + col2im_shape.set(width_idx, convolved_dims.width); + col2im_shape.set(height_idx, convolved_dims.height); + col2im_shape.set(channel_idx, input.tensor_shape()[0] * num_groups); return col2im_shape; } |