diff options
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 36 | ||||
-rw-r--r-- | arm_compute/core/utils/misc/Utility.h | 16 |
2 files changed, 51 insertions, 1 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 65a2a1edf4..698a2b7a45 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -39,6 +39,42 @@ namespace misc { namespace shape_calculator { +/** Calculate the output tensor shape for the reduce mean operation + * + * @param[in] input Input tensor shape + * @param[in] reduction_axis Reduction axis + * @param[in] keep_dims Flag to indicate if dimensions are kept + * + * @return the calculated shape + */ +inline TensorShape calculate_reduce_mean_shape(ITensor *input, const Coordinates &reduction_axis, bool keep_dims) +{ + const int reduction_ops = reduction_axis.num_dimensions(); + Coordinates axis_local = reduction_axis; + const int input_dims = input->info()->num_dimensions(); + convert_negative_axis(axis_local, input_dims); + TensorShape out_shape = input->info()->tensor_shape(); + // Configure reshape layer if we want to drop the dimensions + if(!keep_dims) + { + // We have to sort the reduction axis vectors in order for remove_dimension + // to work properly + std::sort(axis_local.begin(), axis_local.begin() + reduction_ops); + for(int i = 0; i < reduction_ops; ++i) + { + out_shape.remove_dimension(axis_local[i] - i); + } + return out_shape; + } + else + { + for(int i = 0; i < reduction_ops; ++i) + { + out_shape.set(axis_local[i], 1); + } + return out_shape; + } +} /** Calculate the output tensor shape of a vector input given the convolution dimensions * * @param[in] input Input tensor shape diff --git a/arm_compute/core/utils/misc/Utility.h b/arm_compute/core/utils/misc/Utility.h index 8dd9afd5cd..2325644e72 100644 --- a/arm_compute/core/utils/misc/Utility.h +++ b/arm_compute/core/utils/misc/Utility.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,6 +53,20 @@ struct index_sequence_generator<0u, S...> : index_sequence<S...> template <std::size_t N> using index_sequence_t = typename index_sequence_generator<N>::type; + +template <typename T, std::size_t N, T val, T... vals> +struct generate_array : generate_array < T, N - 1, val, val, vals... > +{ +}; + +template <typename T, T val, T... vals> +struct generate_array<T, 0, val, vals...> +{ + static constexpr std::array<T, sizeof...(vals)> value{ vals... }; +}; + +template <typename T, T val, T... vals> +constexpr std::array<T, sizeof...(vals)> generate_array<T, 0, val, vals...>::value; /** @endcond */ namespace detail |