aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/Helpers.inl
diff options
context:
space:
mode:
authorDiego Lopez Recas <Diego.LopezRecas@arm.com>2018-02-22 13:08:01 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:48:11 +0000
commitf03561b1dd1b83d44a5c20b0ff349c428efb716c (patch)
treeb27bc693a33ea65db10a135b1398156b3908a58c /arm_compute/core/Helpers.inl
parent23fe7c2729e657a1886911f01dd544de732ebf61 (diff)
downloadComputeLibrary-f03561b1dd1b83d44a5c20b0ff349c428efb716c.tar.gz
IVGCVSW-1018 Fix valid region for Scale
Change-Id: I28081320fb853e905c545d6ce743f223066d0f8c Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/121928 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/core/Helpers.inl')
-rw-r--r--arm_compute/core/Helpers.inl113
1 files changed, 89 insertions, 24 deletions
diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl
index 8b86c22676..899805a701 100644
--- a/arm_compute/core/Helpers.inl
+++ b/arm_compute/core/Helpers.inl
@@ -290,33 +290,98 @@ inline bool set_quantization_info_if_empty(ITensorInfo &info, QuantizationInfo q
inline ValidRegion calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape, InterpolationPolicy policy, BorderSize border_size, bool border_undefined)
{
- const auto wr = static_cast<float>(dst_shape[0]) / static_cast<float>(src_info.tensor_shape()[0]);
- const auto hr = static_cast<float>(dst_shape[1]) / static_cast<float>(src_info.tensor_shape()[1]);
+ const float scale_x = static_cast<float>(dst_shape[0]) / src_info.tensor_shape()[0];
+ const float scale_y = static_cast<float>(dst_shape[1]) / src_info.tensor_shape()[1];
+
+ int valid_start_out_x;
+ int valid_start_out_y;
+ int valid_end_out_x;
+ int valid_end_out_y;
+
+ switch(policy)
+ {
+ case InterpolationPolicy::NEAREST_NEIGHBOR:
+ {
+ const float valid_start_in_x = src_info.valid_region().anchor[0];
+ const float valid_start_in_y = src_info.valid_region().anchor[1];
+
+ const float valid_end_in_x = src_info.valid_region().anchor[0] + src_info.valid_region().shape[0];
+ const float valid_end_in_y = src_info.valid_region().anchor[1] + src_info.valid_region().shape[1];
+
+ // (start_out + 0.5) / scale >= start_in
+ // --> start_out >= (start_in * scale) - 0.5
+ // --> start_out = ceil((start_in * scale) - 0.5)
+ valid_start_out_x = std::ceil((valid_start_in_x * scale_x) - 0.5f);
+ valid_start_out_y = std::ceil((valid_start_in_y * scale_y) - 0.5f);
+
+ // (end_out - 0.5) / scale < end_in
+ // --> end_out < (end_in * scale) + 0.5
+ // --> end_out = ceil((end_in * scale) + 0.5 - 1)
+ valid_end_out_x = std::ceil((valid_end_in_x * scale_x) - 0.5f);
+ valid_end_out_y = std::ceil((valid_end_in_y * scale_y) - 0.5f);
+
+ break;
+ }
+ case InterpolationPolicy::BILINEAR:
+ {
+ const float k = border_undefined ? 0.5f : 0.0f;
+
+ const float valid_start_in_x = src_info.valid_region().anchor[0] + k;
+ const float valid_start_in_y = src_info.valid_region().anchor[1] + k;
+
+ const float valid_end_in_x = src_info.valid_region().anchor[0] + src_info.valid_region().shape[0] - k;
+ const float valid_end_in_y = src_info.valid_region().anchor[1] + src_info.valid_region().shape[1] - k;
+
+ // (start_out + 0.5) / scale >= start_in
+ // --> start_out >= (start_in * scale) - 0.5
+ // --> start_out = ceil((start_in * scale) - 0.5)
+ valid_start_out_x = std::ceil((valid_start_in_x * scale_x) - 0.5f);
+ valid_start_out_y = std::ceil((valid_start_in_y * scale_y) - 0.5f);
+
+ // (end_out - 0.5) / scale <= end_in
+ // --> end_out <= (end_in * scale) + 0.5
+ // --> end_out = floor((end_in * scale) + 0.5)
+ valid_end_out_x = std::floor((valid_end_in_x * scale_x) + 0.5f);
+ valid_end_out_y = std::floor((valid_end_in_y * scale_y) + 0.5f);
+
+ break;
+ }
+ case InterpolationPolicy::AREA:
+ {
+ const float valid_start_in_x = src_info.valid_region().anchor[0];
+ const float valid_start_in_y = src_info.valid_region().anchor[1];
+
+ const float valid_end_in_x = src_info.valid_region().anchor[0] + src_info.valid_region().shape[0];
+ const float valid_end_in_y = src_info.valid_region().anchor[1] + src_info.valid_region().shape[1];
+
+ // start_out / scale >= start_in
+ // --> start_out >= start_in * scale
+ // --> start_out = ceil(start_in * scale)
+ valid_start_out_x = std::ceil(valid_start_in_x * scale_x);
+ valid_start_out_y = std::ceil(valid_start_in_y * scale_y);
+
+ // end_out / scale <= end_in
+ // --> end_out <= end_in * scale
+ // --> end_out = floor(end_in * scale)
+ valid_end_out_x = std::floor(valid_end_in_x * scale_x);
+ valid_end_out_y = std::floor(valid_end_in_y * scale_y);
+
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Invalid InterpolationPolicy");
+ break;
+ }
+ }
ValidRegion valid_region{ Coordinates(), dst_shape, src_info.tensor_shape().num_dimensions() };
- Coordinates &anchor = valid_region.anchor;
- TensorShape &shape = valid_region.shape;
-
- anchor.set(0, (policy == InterpolationPolicy::BILINEAR
- && border_undefined) ?
- ((static_cast<int>(src_info.valid_region().anchor[0]) + border_size.left + 0.5f) * wr - 0.5f) :
- ((static_cast<int>(src_info.valid_region().anchor[0]) + 0.5f) * wr - 0.5f));
- anchor.set(1, (policy == InterpolationPolicy::BILINEAR
- && border_undefined) ?
- ((static_cast<int>(src_info.valid_region().anchor[1]) + border_size.top + 0.5f) * hr - 0.5f) :
- ((static_cast<int>(src_info.valid_region().anchor[1]) + 0.5f) * hr - 0.5f));
- float shape_out_x = (policy == InterpolationPolicy::BILINEAR
- && border_undefined) ?
- ((static_cast<int>(src_info.valid_region().anchor[0]) + static_cast<int>(src_info.valid_region().shape[0]) - 1) - 1 + 0.5f) * wr - 0.5f :
- ((static_cast<int>(src_info.valid_region().anchor[0]) + static_cast<int>(src_info.valid_region().shape[0])) + 0.5f) * wr - 0.5f;
- float shape_out_y = (policy == InterpolationPolicy::BILINEAR
- && border_undefined) ?
- ((static_cast<int>(src_info.valid_region().anchor[1]) + static_cast<int>(src_info.valid_region().shape[1]) - 1) - 1 + 0.5f) * hr - 0.5f :
- ((static_cast<int>(src_info.valid_region().anchor[1]) + static_cast<int>(src_info.valid_region().shape[1])) + 0.5f) * hr - 0.5f;
-
- shape.set(0, shape_out_x - anchor[0]);
- shape.set(1, shape_out_y - anchor[1]);
+ valid_region.anchor.set(0, valid_start_out_x);
+ valid_region.anchor.set(1, valid_start_out_y);
+
+ valid_region.shape.set(0, static_cast<size_t>(valid_end_out_x - valid_start_out_x));
+ valid_region.shape.set(1, static_cast<size_t>(valid_end_out_y - valid_start_out_y));
return valid_region;
}