aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h36
-rw-r--r--arm_compute/core/utils/misc/Utility.h16
2 files changed, 51 insertions, 1 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 65a2a1edf4..698a2b7a45 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -39,6 +39,42 @@ namespace misc
{
namespace shape_calculator
{
+/** Calculate the output tensor shape for the reduce mean operation
+ *
+ * @param[in] input Input tensor shape
+ * @param[in] reduction_axis Reduction axis
+ * @param[in] keep_dims Flag to indicate if dimensions are kept
+ *
+ * @return the calculated shape
+ */
+inline TensorShape calculate_reduce_mean_shape(ITensor *input, const Coordinates &reduction_axis, bool keep_dims)
+{
+ const int reduction_ops = reduction_axis.num_dimensions();
+ Coordinates axis_local = reduction_axis;
+ const int input_dims = input->info()->num_dimensions();
+ convert_negative_axis(axis_local, input_dims);
+ TensorShape out_shape = input->info()->tensor_shape();
+ // Configure reshape layer if we want to drop the dimensions
+ if(!keep_dims)
+ {
+ // We have to sort the reduction axis vectors in order for remove_dimension
+ // to work properly
+ std::sort(axis_local.begin(), axis_local.begin() + reduction_ops);
+ for(int i = 0; i < reduction_ops; ++i)
+ {
+ out_shape.remove_dimension(axis_local[i] - i);
+ }
+ return out_shape;
+ }
+ else
+ {
+ for(int i = 0; i < reduction_ops; ++i)
+ {
+ out_shape.set(axis_local[i], 1);
+ }
+ return out_shape;
+ }
+}
/** Calculate the output tensor shape of a vector input given the convolution dimensions
*
* @param[in] input Input tensor shape
diff --git a/arm_compute/core/utils/misc/Utility.h b/arm_compute/core/utils/misc/Utility.h
index 8dd9afd5cd..2325644e72 100644
--- a/arm_compute/core/utils/misc/Utility.h
+++ b/arm_compute/core/utils/misc/Utility.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,20 @@ struct index_sequence_generator<0u, S...> : index_sequence<S...>
template <std::size_t N>
using index_sequence_t = typename index_sequence_generator<N>::type;
+
+template <typename T, std::size_t N, T val, T... vals>
+struct generate_array : generate_array < T, N - 1, val, val, vals... >
+{
+};
+
+template <typename T, T val, T... vals>
+struct generate_array<T, 0, val, vals...>
+{
+ static constexpr std::array<T, sizeof...(vals)> value{ vals... };
+};
+
+template <typename T, T val, T... vals>
+constexpr std::array<T, sizeof...(vals)> generate_array<T, 0, val, vals...>::value;
/** @endcond */
namespace detail