diff options
-rw-r--r-- | arm_compute/core/Utils.h | 32 | ||||
-rw-r--r-- | src/runtime/CL/functions/CLDepthConcatenate.cpp | 6 | ||||
-rw-r--r-- | src/runtime/NEON/functions/NEDepthConcatenate.cpp | 6 |
3 files changed, 44 insertions, 0 deletions
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h index ab5d110f91..06d674644b 100644 --- a/arm_compute/core/Utils.h +++ b/arm_compute/core/Utils.h @@ -35,6 +35,7 @@ #include <string> #include <type_traits> #include <utility> +#include <vector> namespace arm_compute { @@ -419,6 +420,37 @@ inline uint32_t calculate_matrix_scale(const int16_t *matrix, unsigned int matri return std::max(1, std::abs(std::accumulate(matrix, matrix + size, 0))); } +/** Calculate the output shape of the depth concatenate function. + * + * @param[in] inputs_vector The vector that stores all the pointers to the input tensors. + * + * @return the output shape + */ +template <typename T> +TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inputs_vector) +{ + TensorShape out_shape = inputs_vector[0]->info()->tensor_shape(); + + size_t max_x = 0; + size_t max_y = 0; + size_t depth = 0; + + for(const auto &tensor : inputs_vector) + { + ARM_COMPUTE_ERROR_ON(tensor == nullptr); + const TensorShape shape = tensor->info()->tensor_shape(); + max_x = std::max(shape.x(), max_x); + max_y = std::max(shape.y(), max_y); + depth += shape.z(); + } + + out_shape.set(0, max_x); + out_shape.set(1, max_y); + out_shape.set(2, depth); + + return out_shape; +} + /** Calculate accuracy required by the horizontal and vertical convolution computations * * @param[in] conv_col Pointer to the vertical vector of the separated convolution filter diff --git a/src/runtime/CL/functions/CLDepthConcatenate.cpp b/src/runtime/CL/functions/CLDepthConcatenate.cpp index f42627f34c..89e44ca98e 100644 --- a/src/runtime/CL/functions/CLDepthConcatenate.cpp +++ b/src/runtime/CL/functions/CLDepthConcatenate.cpp @@ -25,6 +25,7 @@ #include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/Error.h" +#include "arm_compute/core/Helpers.h" #include "arm_compute/core/PixelValue.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/CL/CLScheduler.h" @@ -51,6 +52,11 @@ void 
CLDepthConcatenate::configure(std::vector<ICLTensor *> inputs_vector, ICLTe _concat_kernels_vector = arm_compute::support::cpp14::make_unique<CLDepthConcatenateKernel[]>(_num_inputs); _border_handlers_vector = arm_compute::support::cpp14::make_unique<CLFillBorderKernel[]>(_num_inputs); + TensorShape output_shape = calculate_depth_concatenate_shape(inputs_vector); + + // Output auto initialization if not yet initialized + auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type(), inputs_vector[0]->info()->fixed_point_position()); + for(unsigned int i = 0; i < _num_inputs; i++) { _concat_kernels_vector[i].configure(inputs_vector.at(i), depth_offset, output); diff --git a/src/runtime/NEON/functions/NEDepthConcatenate.cpp b/src/runtime/NEON/functions/NEDepthConcatenate.cpp index 90eee4f45f..f8ad2abe61 100644 --- a/src/runtime/NEON/functions/NEDepthConcatenate.cpp +++ b/src/runtime/NEON/functions/NEDepthConcatenate.cpp @@ -24,6 +24,7 @@ #include "arm_compute/runtime/NEON/functions/NEDepthConcatenate.h" #include "arm_compute/core/Error.h" +#include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" #include "arm_compute/core/PixelValue.h" #include "arm_compute/core/Types.h" @@ -48,6 +49,11 @@ void NEDepthConcatenate::configure(std::vector<ITensor *> inputs_vector, ITensor _concat_kernels_vector = arm_compute::support::cpp14::make_unique<NEDepthConcatenateKernel[]>(_num_inputs); _border_handlers_vector = arm_compute::support::cpp14::make_unique<NEFillBorderKernel[]>(_num_inputs); + TensorShape output_shape = calculate_depth_concatenate_shape(inputs_vector); + + // Output auto initialization if not yet initialized + auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type(), inputs_vector[0]->info()->fixed_point_position()); + unsigned int depth_offset = 0; for(unsigned int i = 0; i < _num_inputs; ++i) { |