aboutsummaryrefslogtreecommitdiff
path: root/src/core
diff options
context:
space:
mode:
Diffstat (limited to 'src/core')
-rw-r--r--src/core/CL/kernels/CLPermuteKernel.cpp8
-rw-r--r--src/core/Helpers.cpp34
2 files changed, 29 insertions, 13 deletions
diff --git a/src/core/CL/kernels/CLPermuteKernel.cpp b/src/core/CL/kernels/CLPermuteKernel.cpp
index 1636e5a1bc..dc2d6fe4b4 100644
--- a/src/core/CL/kernels/CLPermuteKernel.cpp
+++ b/src/core/CL/kernels/CLPermuteKernel.cpp
@@ -75,16 +75,16 @@ void CLPermuteKernel::configure(const ICLTensor *input, ICLTensor *output, const
void CLPermuteKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output, const PermutationVector &perm)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ const TensorShape output_shape = get_output_shape(input->info(), perm);
+ // Output auto inizialitation if not yet initialized
+ auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
+
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), perm));
_input = input;
_output = output;
_perm = perm;
- const TensorShape output_shape = get_output_shape(input->info(), perm);
- // Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
-
// Create kernel
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(data_size_from_type(input->info()->data_type())));
diff --git a/src/core/Helpers.cpp b/src/core/Helpers.cpp
index bfc4a8d101..5c7200b35c 100644
--- a/src/core/Helpers.cpp
+++ b/src/core/Helpers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2018 Arm Limited.
+ * Copyright (c) 2016-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,9 +23,9 @@
*/
#include "arm_compute/core/Helpers.h"
-using namespace arm_compute;
-
-Window arm_compute::calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
+namespace arm_compute
+{
+Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
{
if(!skip_border)
{
@@ -79,7 +79,7 @@ Window arm_compute::calculate_max_window(const ValidRegion &valid_region, const
return window;
}
-Window arm_compute::calculate_max_enlarged_window(const ValidRegion &valid_region, const Steps &steps, BorderSize border_size)
+Window calculate_max_enlarged_window(const ValidRegion &valid_region, const Steps &steps, BorderSize border_size)
{
const Coordinates &anchor = valid_region.anchor;
const TensorShape &shape = valid_region.shape;
@@ -128,7 +128,7 @@ Window arm_compute::calculate_max_enlarged_window(const ValidRegion &valid_regio
return window;
}
-Window arm_compute::calculate_max_window_horizontal(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
+Window calculate_max_window_horizontal(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
{
if(skip_border)
{
@@ -181,8 +181,8 @@ Window arm_compute::calculate_max_window_horizontal(const ValidRegion &valid_reg
return window;
}
-ValidRegion arm_compute::calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape,
- InterpolationPolicy interpolate_policy, SamplingPolicy sampling_policy, bool border_undefined)
+ValidRegion calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape,
+ InterpolationPolicy interpolate_policy, SamplingPolicy sampling_policy, bool border_undefined)
{
const DataLayout data_layout = src_info.data_layout();
const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
@@ -255,4 +255,20 @@ ValidRegion arm_compute::calculate_valid_region_scale(const ITensorInfo &src_inf
valid_region.shape.set(idx_height, std::min<size_t>(valid_end_out_y - valid_start_out_y, dst_shape[idx_height]));
return valid_region;
-} \ No newline at end of file
+}
+
+PermutationVector get_permutation_vector_from_softmax_axis(size_t actual_axis)
+{
+ switch(actual_axis)
+ {
+ case 1:
+ return PermutationVector(1U, 0U, 2U, 3U);
+ case 2:
+ return PermutationVector(2U, 1U, 0U, 3U);
+ case 3:
+ return PermutationVector(3U, 1U, 2U, 0U);
+ default:
+ ARM_COMPUTE_ERROR("Axis not supported");
+ }
+}
+} // namespace arm_compute \ No newline at end of file