diff options
author | steniu01 <steven.niu@arm.com> | 2017-09-29 14:55:00 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | 7ce53c620b50c718bac62017d28072cf61457233 (patch) | |
tree | 80a4336ce63f7794986777e8434e9e440478f70d /arm_compute/core/Utils.h | |
parent | b482ce1d601a777250f28ed118ac250943aca4eb (diff) | |
download | ComputeLibrary-7ce53c620b50c718bac62017d28072cf61457233.tar.gz |
COMPMID-546 Add auto config to depth concatenate
Change-Id: I7798a56677d541338a73e3888ed0a2cfe0375794
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/89726
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'arm_compute/core/Utils.h')
-rw-r--r-- | arm_compute/core/Utils.h | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h index ab5d110f91..06d674644b 100644 --- a/arm_compute/core/Utils.h +++ b/arm_compute/core/Utils.h @@ -35,6 +35,7 @@ #include <string> #include <type_traits> #include <utility> +#include <vector> namespace arm_compute { @@ -419,6 +420,37 @@ inline uint32_t calculate_matrix_scale(const int16_t *matrix, unsigned int matri return std::max(1, std::abs(std::accumulate(matrix, matrix + size, 0))); } +/** Calculate the output shapes of the depth concatenate function. + * + * @param[in] inputs_vector The vector that stores all the pointers to input. + * + * @return the output shape + */ +template <typename T> +TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inputs_vector) +{ + TensorShape out_shape = inputs_vector[0]->info()->tensor_shape(); + + size_t max_x = 0; + size_t max_y = 0; + size_t depth = 0; + + for(const auto &tensor : inputs_vector) + { + ARM_COMPUTE_ERROR_ON(tensor == nullptr); + const TensorShape shape = tensor->info()->tensor_shape(); + max_x = std::max(shape.x(), max_x); + max_y = std::max(shape.y(), max_y); + depth += shape.z(); + } + + out_shape.set(0, max_x); + out_shape.set(1, max_y); + out_shape.set(2, depth); + + return out_shape; +} + /** Calculate accurary required by the horizontal and vertical convolution computations * * @param[in] conv_col Pointer to the vertical vector of the separated convolution filter |