aboutsummaryrefslogtreecommitdiff
path: root/src/core/Helpers.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-05-08 15:54:53 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:52:19 +0000
commit393fa4c87c84356132303170d1b9ce9a45b3c3bf (patch)
treeb5d5a7ca835d625b5afd56155be8ad9de7ab6575 /src/core/Helpers.cpp
parent1731d5133f1b081fc669d082ae8c3e744d36ab11 (diff)
downloadComputeLibrary-393fa4c87c84356132303170d1b9ce9a45b3c3bf.tar.gz
COMPMID-814: NEScale NHWC support
Change-Id: Ibf5c624a5c5482faa42eb02bc8abe9ae0d65b0d1 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/130608 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/Helpers.cpp')
-rw-r--r--src/core/Helpers.cpp28
1 files changed, 16 insertions, 12 deletions
diff --git a/src/core/Helpers.cpp b/src/core/Helpers.cpp
index c39922bf03..e336331663 100644
--- a/src/core/Helpers.cpp
+++ b/src/core/Helpers.cpp
@@ -177,21 +177,25 @@ Window arm_compute::calculate_max_window_horizontal(const ValidRegion &valid_reg
ValidRegion arm_compute::calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape,
InterpolationPolicy interpolate_policy, SamplingPolicy sampling_policy, bool border_undefined)
{
- const float scale_x = static_cast<float>(dst_shape[0]) / src_info.tensor_shape()[0];
- const float scale_y = static_cast<float>(dst_shape[1]) / src_info.tensor_shape()[1];
+ const DataLayout data_layout = src_info.data_layout();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
+ const float scale_x = static_cast<float>(dst_shape[idx_width]) / src_info.tensor_shape()[idx_width];
+ const float scale_y = static_cast<float>(dst_shape[idx_height]) / src_info.tensor_shape()[idx_height];
const float sampling_point = (sampling_policy == SamplingPolicy::CENTER) ? 0.5f : 0.0f;
// Get input's valid region start and end points
- const int valid_start_in_x = src_info.valid_region().anchor[0];
- const int valid_start_in_y = src_info.valid_region().anchor[1];
- const int valid_end_in_x = src_info.valid_region().anchor[0] + src_info.valid_region().shape[0];
- const int valid_end_in_y = src_info.valid_region().anchor[1] + src_info.valid_region().shape[1];
+ const int valid_start_in_x = src_info.valid_region().anchor[idx_width];
+ const int valid_start_in_y = src_info.valid_region().anchor[idx_height];
+ const int valid_end_in_x = src_info.valid_region().anchor[idx_width] + src_info.valid_region().shape[idx_width];
+ const int valid_end_in_y = src_info.valid_region().anchor[idx_height] + src_info.valid_region().shape[idx_height];
// Initialize output's valid region start and end points
auto valid_start_out_x = static_cast<int>(valid_start_in_x * scale_x);
auto valid_start_out_y = static_cast<int>(valid_start_in_y * scale_y);
- auto valid_end_out_x = std::min<int>(std::ceil(valid_end_in_x * scale_x), dst_shape[0]);
- auto valid_end_out_y = std::min<int>(std::ceil(valid_end_in_y * scale_y), dst_shape[1]);
+ auto valid_end_out_x = std::min<int>(std::ceil(valid_end_in_x * scale_x), dst_shape[idx_width]);
+ auto valid_end_out_y = std::min<int>(std::ceil(valid_end_in_y * scale_y), dst_shape[idx_height]);
// Handle valid points in case of the bi-linear interpolation
if(border_undefined)
@@ -237,11 +241,11 @@ ValidRegion arm_compute::calculate_valid_region_scale(const ITensorInfo &src_inf
// Setup output valid region
ValidRegion valid_region{ Coordinates(), dst_shape, src_info.tensor_shape().num_dimensions() };
- valid_region.anchor.set(0, std::max(0, valid_start_out_x));
- valid_region.anchor.set(1, std::max(0, valid_start_out_y));
+ valid_region.anchor.set(idx_width, std::max(0, valid_start_out_x));
+ valid_region.anchor.set(idx_height, std::max(0, valid_start_out_y));
- valid_region.shape.set(0, std::min<size_t>(valid_end_out_x - valid_start_out_x, dst_shape[0]));
- valid_region.shape.set(1, std::min<size_t>(valid_end_out_y - valid_start_out_y, dst_shape[1]));
+ valid_region.shape.set(idx_width, std::min<size_t>(valid_end_out_x - valid_start_out_x, dst_shape[idx_width]));
+ valid_region.shape.set(idx_height, std::min<size_t>(valid_end_out_y - valid_start_out_y, dst_shape[idx_height]));
return valid_region;
} \ No newline at end of file