diff options
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index b91e52a657..8d4c024f62 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -168,9 +168,9 @@ inline TensorShape compute_im2col_conv_shape(const ITensorInfo *input, const Siz const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); std::pair<unsigned int, unsigned int> out_dims = scaled_dimensions(output_shape[width_idx], output_shape[height_idx], kernel_dims.width, kernel_dims.height, conv_info, dilation); - output_shape.set(width_idx, (output_shape[channel_idx] * kernel_dims.area() + (has_bias ? 1 : 0))); - output_shape.set(height_idx, (out_dims.first * out_dims.second)); - output_shape.set(channel_idx, 1); + output_shape.set(0, (output_shape[channel_idx] * kernel_dims.area() + (has_bias ? 1 : 0))); + output_shape.set(1, (out_dims.first * out_dims.second)); + output_shape.set(2, 1); return output_shape; } |