diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 9d36405041..6782cda7fe 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1222,30 +1222,30 @@ inline TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inp return out_shape; } -/** Calculate the width concatenate output shape of a vector of tensors +/** Calculate the concatenate output shape of the concatenate operation along a single axis * - * @param[in] inputs_vector Vector containing the shapes of the inputs + * @param[in] input Vector containing the shapes of the inputs + * @param[in] axis Axis along which to concatenate the input tensors * * @return the calculated shape */ template <typename T> -inline TensorShape calculate_width_concatenate_shape(const std::vector<T *> &inputs_vector) +inline TensorShape calculate_concatenate_shape(const std::vector<T *> &input, size_t axis) { - TensorShape out_shape = extract_shape(inputs_vector[0]); + TensorShape out_shape = extract_shape(input[0]); - size_t width = 0; - for(const auto &tensor : inputs_vector) + size_t new_size = 0; + for(const auto &tensor : input) { ARM_COMPUTE_ERROR_ON(tensor == nullptr); const TensorShape shape = extract_shape(tensor); - width += shape.x(); + new_size += shape[axis]; } - out_shape.set(0, width); + out_shape.set(axis, new_size); return out_shape; } - /** Calculate the stack output shape of a tensor * * @param[in] a Input tensor info |