diff options
Diffstat (limited to 'arm_compute/core')
-rw-r--r-- | arm_compute/core/Utils.h | 10 | ||||
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 17 |
2 files changed, 23 insertions, 4 deletions
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h index 3f04ed9963..3939491bb2 100644 --- a/arm_compute/core/Utils.h +++ b/arm_compute/core/Utils.h @@ -881,6 +881,16 @@ std::pair<unsigned int, unsigned int> scaled_dimensions(unsigned int width, unsi const PadStrideInfo &pad_stride_info, const Size2D &dilation = Size2D(1U, 1U)); +/** Check if the given reduction operation should be handled in a serial way. + * + * @param[in] op Reduction operation to perform + * @param[in] dt Data type + * @param[in] axis Axis along which to reduce + * + * @return True if the given reduction operation should be handled in a serial way. + */ +bool needs_serialized_reduction(ReductionOperation op, DataType dt, unsigned int axis); + /** Convert a tensor format into a string. * * @param[in] format @ref Format to be translated to string. diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index c4c360842f..080d63f60d 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1179,15 +1179,24 @@ inline TensorShape compute_tiled_shape(const TensorShape &input_shape, const Mul /** Calculate the reduced shape of a tensor given an axis * - * @param[in] input Input tensor info - * @param[in] axis Axis on which to perform reduction + * @param[in] input Input tensor info + * @param[in] axis Axis on which to perform reduction + * @param[in] keep_dims (Optional) Whether to keep the dimension after reduction operation. Defaults to true. * * @return the calculated shape */ -inline TensorShape compute_reduced_shape(const TensorShape &input, unsigned int axis) +inline TensorShape compute_reduced_shape(const TensorShape &input, unsigned int axis, bool keep_dims = true) { TensorShape output_shape{ input }; - output_shape.set(axis, 1); + + if(!keep_dims) + { + output_shape.remove_dimension(axis); + } + else + { + output_shape.set(axis, 1); + } return output_shape; } |