diff options
Diffstat (limited to 'arm_compute/runtime/CPP/functions/CPPSplit.h')
-rw-r--r-- | arm_compute/runtime/CPP/functions/CPPSplit.h | 64 |
1 file changed, 33 insertions, 31 deletions
diff --git a/arm_compute/runtime/CPP/functions/CPPSplit.h b/arm_compute/runtime/CPP/functions/CPPSplit.h index 6adcbc3323..9be081f5bb 100644 --- a/arm_compute/runtime/CPP/functions/CPPSplit.h +++ b/arm_compute/runtime/CPP/functions/CPPSplit.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 ARM Limited. + * Copyright (c) 2020-2021,2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,9 +29,6 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" - -#include "support/ToolchainSupport.h" - #include "arm_compute/runtime/IFunction.h" namespace arm_compute @@ -41,14 +38,13 @@ template <typename SliceType, typename TensorInterfaceType = ITensor> class CPPSplit : public IFunction { public: - CPPSplit() - : _outputs_vector(), _slice_functions(), _num_outputs(0) + CPPSplit() : _outputs_vector(), _slice_functions(), _num_outputs(0) { } /** Static function to check if given info will lead to a valid configuration of @ref CPPSplit * - * @param[in] input The input tensor info. Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32. - * @param[in] outputs A vector containing the output tensors' info. Data types supported: Same as @p input. + * @param[in] input The input tensor info. Data types supported: All. + * @param[in] outputs A vector containing the output tensors' info. Data types supported: same as @p input. * The output tensors should match the input tensor dimensions for all shape dimensions apart * from the split dimension * @param[in] axis Axis on which to split the input. 
@@ -66,14 +62,16 @@ public: unsigned int total_output_shape_size = 0; // Sum the output sizes and fall back to evenly-sized splits if any are zero - const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), [&total_output_shape_size](ITensorInfo * info) - { - unsigned int output_shape_size = info->tensor_shape().total_size(); - total_output_shape_size += output_shape_size; - return output_shape_size == 0; - }); - - if(using_split_shapes) + const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), + [&total_output_shape_size](ITensorInfo *info) + { + unsigned int output_shape_size = + info->tensor_shape().total_size(); + total_output_shape_size += output_shape_size; + return output_shape_size == 0; + }); + + if (using_split_shapes) { ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size() != total_output_shape_size); } @@ -85,10 +83,10 @@ public: // Validate output tensors unsigned int axis_offset = 0; - for(const auto &output : outputs) + for (const auto &output : outputs) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); - if(using_split_shapes) + if (using_split_shapes) { output_shape = output->tensor_shape(); ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0); @@ -99,14 +97,17 @@ public: // Start/End coordinates Coordinates start_coords; Coordinates end_coords; - for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d) + for (unsigned int d = 0; d < output_shape.num_dimensions(); ++d) { end_coords.set(d, -1); } // Output auto inizialitation if not yet initialized TensorInfo tmp_output_info = *output->clone(); - auto_init_if_empty(tmp_output_info, input->clone()->set_is_resizable(true).set_tensor_shape(output_shape)); + if (tmp_output_info.tensor_shape().total_size() == 0) + { + tmp_output_info = input->clone()->set_is_resizable(true).set_tensor_shape(output_shape); + } // Update coordinate on axis start_coords.set(axis, axis_offset); @@ -127,7 +128,8 @@ public: * from the split dimension. 
* @param[in] axis Axis on which to split the input. */ - void configure(const TensorInterfaceType *input, const std::vector<TensorInterfaceType *> &outputs, unsigned int axis) + void + configure(const TensorInterfaceType *input, const std::vector<TensorInterfaceType *> &outputs, unsigned int axis) { // Create Slice functions _num_outputs = outputs.size(); @@ -135,17 +137,16 @@ public: // Extract output tensor info std::vector<ITensorInfo *> outputs_info; - for(auto &output : outputs) + for (auto &output : outputs) { ARM_COMPUTE_ERROR_ON_NULLPTR(output); outputs_info.emplace_back(output->info()); } // If any of the outputs have a zero size, fall-back to using evenly-sized output splits - const bool outputs_have_sizes = std::none_of(outputs_info.begin(), outputs_info.end(), [](ITensorInfo * info) - { - return info->tensor_shape().total_size() == 0; - }); + const bool outputs_have_sizes = + std::none_of(outputs_info.begin(), outputs_info.end(), + [](ITensorInfo *info) { return info->tensor_shape().total_size() == 0; }); // Validate ARM_COMPUTE_ERROR_THROW_ON(CPPSplit::validate(input->info(), outputs_info, axis)); @@ -153,12 +154,13 @@ public: unsigned int axis_offset = 0; unsigned int i = 0; - for(const auto &output_info : outputs_info) + for (const auto &output_info : outputs_info) { // Get output shape - TensorShape output_shape = (outputs_have_sizes ? - output_info->tensor_shape() : - arm_compute::misc::shape_calculator::compute_split_shape(input->info(), axis, _num_outputs)); + TensorShape output_shape = + (outputs_have_sizes + ? output_info->tensor_shape() + : arm_compute::misc::shape_calculator::compute_split_shape(input->info(), axis, _num_outputs)); const size_t axis_split_step = output_shape[axis]; @@ -166,7 +168,7 @@ public: Coordinates start_coords; Coordinates end_coords; - for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d) + for (unsigned int d = 0; d < output_shape.num_dimensions(); ++d) { end_coords.set(d, -1); } |