aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h42
1 files changed, 34 insertions, 8 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index c3d5b64a92..e174227302 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -107,13 +107,6 @@ inline TensorShape compute_reductionB_shape(const ITensorInfo &a)
return shape_vector_sum_row;
}
-inline TensorShape compute_im2col_shape(const ITensorInfo &input)
-{
- TensorShape shape_im2col{ input.tensor_shape() };
- shape_im2col.collapse(3);
-
- return shape_im2col;
-}
inline TensorShape compute_col2im_shape(const ITensorInfo &input, std::pair<unsigned int, unsigned int> convolved_dims)
{
TensorShape col2im_shape{ input.tensor_shape() };
@@ -159,7 +152,25 @@ inline TensorShape compute_deconvolution_shape(const ITensorInfo &input, unsigne
return scale_out_shape;
}
-inline TensorShape compute_im2col_shape(const ITensorInfo *input, const int num_input_dimensions = 3)
+inline TensorShape compute_im2col_conv_shape(const ITensorInfo *input, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation)
+{
+ // The output shape will be the 2D shape used as input for GEMM [ out_channels * kernel_area, num_elems_per_out_channel ]
+
+ TensorShape output_shape{ input->tensor_shape() };
+
+ const DataLayout data_layout = input->data_layout();
+ const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+
+ std::pair<unsigned int, unsigned int> out_dims = scaled_dimensions(output_shape[width_idx], output_shape[height_idx], kernel_dims.width, kernel_dims.height, conv_info, dilation);
+ output_shape.set(width_idx, (output_shape[channel_idx] * kernel_dims.area() + (has_bias ? 1 : 0)));
+ output_shape.set(height_idx, (out_dims.first * out_dims.second));
+ output_shape.set(channel_idx, 1);
+
+ return output_shape;
+}
+inline TensorShape compute_im2col_fc_shape(const ITensorInfo *input, const int num_input_dimensions = 3)
{
TensorShape output_shape{ input->tensor_shape() };
@@ -167,6 +178,21 @@ inline TensorShape compute_im2col_shape(const ITensorInfo *input, const int num_
return output_shape;
}
+inline TensorShape compute_im2col_flatten_shape(const ITensorInfo *input)
+{
+ // The output shape will be the flatten version of the input (i.e. [ width * height * channels, 1, 1, ... ] ). Used for FlattenLayer.
+
+ ARM_COMPUTE_ERROR_ON(input->num_dimensions() < 3);
+
+ TensorShape output_shape{ input->tensor_shape() };
+
+ const size_t flatten_shape = input->dimension(0) * input->dimension(1) * input->dimension(2);
+ output_shape.set(0, flatten_shape);
+ output_shape.remove_dimension(1);
+ output_shape.remove_dimension(1);
+
+ return output_shape;
+}
inline TensorShape compute_interleave_custom_shape(const TensorShape &input, const int x_interleave, const int y_interleave)
{
TensorShape output_shape{ input };