diff options
Diffstat (limited to 'src/core')
-rw-r--r-- | src/core/CL/kernels/CLPermuteKernel.cpp | 8 | ||||
-rw-r--r-- | src/core/Helpers.cpp | 34 |
2 files changed, 29 insertions, 13 deletions
diff --git a/src/core/CL/kernels/CLPermuteKernel.cpp b/src/core/CL/kernels/CLPermuteKernel.cpp index 1636e5a1bc..dc2d6fe4b4 100644 --- a/src/core/CL/kernels/CLPermuteKernel.cpp +++ b/src/core/CL/kernels/CLPermuteKernel.cpp @@ -75,16 +75,16 @@ void CLPermuteKernel::configure(const ICLTensor *input, ICLTensor *output, const void CLPermuteKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output, const PermutationVector &perm) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + const TensorShape output_shape = get_output_shape(input->info(), perm); + // Output auto inizialitation if not yet initialized + auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), perm)); _input = input; _output = output; _perm = perm; - const TensorShape output_shape = get_output_shape(input->info(), perm); - // Output auto inizialitation if not yet initialized - auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape)); - // Create kernel CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(data_size_from_type(input->info()->data_type()))); diff --git a/src/core/Helpers.cpp b/src/core/Helpers.cpp index bfc4a8d101..5c7200b35c 100644 --- a/src/core/Helpers.cpp +++ b/src/core/Helpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2018 Arm Limited. + * Copyright (c) 2016-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,9 +23,9 @@ */ #include "arm_compute/core/Helpers.h" -using namespace arm_compute; - -Window arm_compute::calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size) +namespace arm_compute +{ +Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size) { if(!skip_border) { @@ -79,7 +79,7 @@ Window arm_compute::calculate_max_window(const ValidRegion &valid_region, const return window; } -Window arm_compute::calculate_max_enlarged_window(const ValidRegion &valid_region, const Steps &steps, BorderSize border_size) +Window calculate_max_enlarged_window(const ValidRegion &valid_region, const Steps &steps, BorderSize border_size) { const Coordinates &anchor = valid_region.anchor; const TensorShape &shape = valid_region.shape; @@ -128,7 +128,7 @@ Window arm_compute::calculate_max_enlarged_window(const ValidRegion &valid_regio return window; } -Window arm_compute::calculate_max_window_horizontal(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size) +Window calculate_max_window_horizontal(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size) { if(skip_border) { @@ -181,8 +181,8 @@ Window arm_compute::calculate_max_window_horizontal(const ValidRegion &valid_reg return window; } -ValidRegion arm_compute::calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape, - InterpolationPolicy interpolate_policy, SamplingPolicy sampling_policy, bool border_undefined) +ValidRegion calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape, + InterpolationPolicy interpolate_policy, SamplingPolicy sampling_policy, bool border_undefined) { const DataLayout data_layout = src_info.data_layout(); const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); @@ -255,4 +255,20 @@ ValidRegion arm_compute::calculate_valid_region_scale(const ITensorInfo &src_inf valid_region.shape.set(idx_height, std::min<size_t>(valid_end_out_y - valid_start_out_y, dst_shape[idx_height])); return valid_region; -}
\ No newline at end of file +} + +PermutationVector get_permutation_vector_from_softmax_axis(size_t actual_axis) +{ + switch(actual_axis) + { + case 1: + return PermutationVector(1U, 0U, 2U, 3U); + case 2: + return PermutationVector(2U, 1U, 0U, 3U); + case 3: + return PermutationVector(3U, 1U, 2U, 0U); + default: + ARM_COMPUTE_ERROR("Axis not supported"); + } +} +} // namespace arm_compute
\ No newline at end of file |