diff options
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 14 |
1 files changed, 4 insertions, 10 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 1e5b9afd0e..0a2a535502 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -215,19 +215,13 @@ inline TensorShape compute_im2col_conv_shape(const ITensorInfo *input, const Siz return output_shape; } -inline TensorShape compute_im2col_fc_shape(const ITensorInfo *input, const int num_input_dimensions = 3) +inline TensorShape compute_flatten_shape(const ITensorInfo *input) { - TensorShape output_shape{ input->tensor_shape() }; - - output_shape.collapse(num_input_dimensions); + // The output shape will be the flatten version of the input (i.e. [ width * height * channels, num_batches, ... ] ). Used for FlattenLayer and FullyConnectedLayer. - return output_shape; -} -inline TensorShape compute_im2col_flatten_shape(const ITensorInfo *input) -{ - // The output shape will be the flatten version of the input (i.e. [ width * height * channels, 1, 1, ... ] ). Used for FlattenLayer. TensorShape output_shape{ input->tensor_shape() }; - output_shape.collapse(3, 0); + + output_shape.collapse(3); return output_shape; } |