diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLConcatenateLayer.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLConcatenateLayer.h | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/arm_compute/runtime/CL/functions/CLConcatenateLayer.h b/arm_compute/runtime/CL/functions/CLConcatenateLayer.h index b69930c7d3..fb9724d167 100644 --- a/arm_compute/runtime/CL/functions/CLConcatenateLayer.h +++ b/arm_compute/runtime/CL/functions/CLConcatenateLayer.h @@ -60,7 +60,8 @@ public: * @param[out] output Output tensor. Data types supported: Same as @p input. * @param[in] axis Concatenation axis. Supported underlying concatenation axis are 0, 1, 2 and 3. */ - void configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, size_t axis); + void configure(std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, size_t axis); + void configure(std::vector<const ICLTensor *> &inputs_vector, ICLTensor *output, size_t axis); /** Static function to check if given info will lead to a valid configuration of @ref CLConcatenateLayer * * @note Input and output tensor dimensions preconditions defer depending on the concatenation axis. @@ -73,11 +74,18 @@ public: * @return a status */ static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, size_t axis); + static Status validate(const std::vector<const ITensorInfo *> &inputs_vector, const ITensorInfo *output, size_t axis); // Inherited methods overridden: void run() override; private: + template <typename TensorType> + void configure_internal(std::vector<TensorType *> &&inputs_vector, ICLTensor *output, size_t axis); + + template <typename TensorInfoType> + static Status validate_internal(const std::vector<TensorInfoType *> &inputs_vector, const ITensorInfo *output, size_t axis); + std::vector<std::unique_ptr<ICLKernel>> _concat_kernels; unsigned int _num_inputs; unsigned int _axis; |