diff options
Diffstat (limited to 'src/runtime/CL/functions/CLConcatenateLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLConcatenateLayer.cpp | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/src/runtime/CL/functions/CLConcatenateLayer.cpp b/src/runtime/CL/functions/CLConcatenateLayer.cpp index 7edea3efac..b9b3c5bb80 100644 --- a/src/runtime/CL/functions/CLConcatenateLayer.cpp +++ b/src/runtime/CL/functions/CLConcatenateLayer.cpp @@ -44,10 +44,10 @@ CLConcatenateLayer::CLConcatenateLayer() { } -void CLConcatenateLayer::configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, DataLayoutDimension axis) +void CLConcatenateLayer::configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, size_t axis) { ARM_COMPUTE_ERROR_ON(output == nullptr); - _axis = get_data_layout_dimension_index(output->info()->data_layout(), axis); + _axis = axis; _num_inputs = inputs_vector.size(); std::vector<ITensorInfo *> inputs_vector_info(inputs_vector.size()); @@ -135,30 +135,29 @@ void CLConcatenateLayer::configure(const std::vector<ICLTensor *> &inputs_vector } } -Status CLConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, DataLayoutDimension axis) +Status CLConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, size_t axis) { ARM_COMPUTE_RETURN_ERROR_ON(output == nullptr); const unsigned int num_inputs = inputs_vector.size(); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); ARM_COMPUTE_RETURN_ERROR_ON(num_inputs < 2); - const unsigned int _axis = get_data_layout_dimension_index(inputs_vector[0]->data_layout(), axis); // Output auto inizialitation if not yet initialized TensorInfo tmp_output_info = *output->clone(); TensorShape output_shape{}; - if(_axis == Window::DimZ) + if(axis == Window::DimZ) { output_shape = arm_compute::misc::shape_calculator::calculate_depth_concatenate_shape(inputs_vector); } else { - output_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, _axis); + output_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, axis); } auto_init_if_empty(tmp_output_info, output_shape, 1, inputs_vector[0]->data_type()); unsigned int offset = 0; - switch(_axis) + switch(axis) { case Window::DimX: { @@ -180,7 +179,7 @@ Status CLConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vec { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input); ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenateLayerKernel::validate(input, offset, &tmp_output_info)); - offset += input->dimension(_axis); + offset += input->dimension(axis); } break; } @@ -191,7 +190,7 @@ Status CLConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vec for(const auto &input : inputs_vector) { ARM_COMPUTE_RETURN_ON_ERROR(CLHeightConcatenateLayerKernel::validate(input, offset, &tmp_output_info)); - offset += input->dimension(_axis); + offset += input->dimension(axis); } break; } @@ -200,7 +199,7 @@ Status CLConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vec for(const auto &input : inputs_vector) { ARM_COMPUTE_RETURN_ON_ERROR(CLDepthConcatenateLayerKernel::validate(input, offset, &tmp_output_info)); - offset += input->dimension(_axis); + offset += input->dimension(axis); } break; } |