diff options
Diffstat (limited to 'arm_compute/core/utils/misc')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 9bf6b046b4..e5516ba154 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -414,6 +414,31 @@ inline TensorShape get_shape_from_info(ITensorInfo *info) } template <typename T> +inline TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inputs_vector) +{ + TensorShape out_shape = get_shape_from_info(inputs_vector[0]); + + size_t max_x = 0; + size_t max_y = 0; + size_t depth = 0; + + for(const auto &tensor : inputs_vector) + { + ARM_COMPUTE_ERROR_ON(tensor == nullptr); + const TensorShape shape = get_shape_from_info(tensor); + max_x = std::max(shape.x(), max_x); + max_y = std::max(shape.y(), max_y); + depth += shape.z(); + } + + out_shape.set(0, max_x); + out_shape.set(1, max_y); + out_shape.set(2, depth); + + return out_shape; +} + +template <typename T> inline TensorShape calculate_width_concatenate_shape(const std::vector<T *> &inputs_vector) { TensorShape out_shape = get_shape_from_info(inputs_vector[0]); |