aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorOperations.h
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2017-07-04 11:21:28 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:17:31 +0100
commit6203153b25fda2c4b8d94fe71337d1145a36c132 (patch)
treeaee1e7db73d09eabf01ed8fb7680f98f97e07c93 /tests/validation/TensorOperations.h
parentd6253602ec7ebcbe6adb0acc51a97a5d8c167495 (diff)
downloadComputeLibrary-6203153b25fda2c4b8d94fe71337d1145a36c132.tar.gz
COMPMID-424 Implemented reference implementation and validation tests (NEON and CL) for Warp Perspective
Changed the behaviour in NEWarpKernel for border mode replicate and constant to stick with the VX specs. When the new coords are out of the valid region, the output will be computed using the values from the border. In the validation tests the validate will be called with tolerance_value 1 and tolerance_number 0.2%, due to some float arithmetic related mismatches. Change-Id: Id4f9d0ef87178f8f8fd38ee17fee0e6f4beb85cd Reviewed-on: http://mpd-gerrit.cambridge.arm.com/80283 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com> Reviewed-by: Steven Niu <steven.niu@arm.com>
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r--tests/validation/TensorOperations.h128
1 files changed, 127 insertions, 1 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 317d22934e..3220d80a04 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -388,7 +388,7 @@ void accumulate_squared(const Tensor<T1> &in, Tensor<T2> &out, uint32_t shift)
}
}
-// Accumulate weighted
+// Accumulate weighted total_size = init_auto_padding(tensor_shape, num_channels, type);
template <typename T>
void accumulate_weighted(const Tensor<T> &in, Tensor<T> &out, float alpha)
{
@@ -748,6 +748,132 @@ void threshold(const Tensor<T> &in, Tensor<T> &out, uint8_t threshold, uint8_t f
}
}
+template <typename T>
+T bilinear_policy(const Tensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, uint8_t constant_border_value)
+{
+ int idx = std::floor(xn);
+ int idy = std::floor(yn);
+
+ const float dx = xn - idx;
+ const float dy = yn - idy;
+ const float dx_1 = 1.0f - dx;
+ const float dy_1 = 1.0f - dy;
+
+ id.set(0, idx);
+ id.set(1, idy);
+ const T tl = tensor_elem_at(in, id, border_mode, constant_border_value);
+ id.set(0, idx + 1);
+ id.set(1, idy);
+ const T tr = tensor_elem_at(in, id, border_mode, constant_border_value);
+ id.set(0, idx);
+ id.set(1, idy + 1);
+ const T bl = tensor_elem_at(in, id, border_mode, constant_border_value);
+ id.set(0, idx + 1);
+ id.set(1, idy + 1);
+ const T br = tensor_elem_at(in, id, border_mode, constant_border_value);
+
+ return tl * (dx_1 * dy_1) + tr * (dx * dy_1) + bl * (dx_1 * dy) + br * (dx * dy);
+}
+
+bool valid_bilinear_policy(float xn, float yn, int width, int height, BorderMode border_mode)
+{
+ if(border_mode != BorderMode::UNDEFINED)
+ {
+ return true;
+ }
+ if((0 <= yn + 1) && (yn + 1 < height) && (0 <= xn + 1) && (xn + 1 < width))
+ {
+ return true;
+ }
+ return false;
+}
+
+// Warp Perspective
+template <typename T>
+void warp_perspective(const Tensor<T> &in, Tensor<T> &out, Tensor<T> &valid_mask, const float *matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
+{
+ // x0 = M00 * x + M01 * y + M02
+ // y0 = M10 * x + M11 * y + M12
+ // z0 = M20 * x + M21 * y + M22
+ // xn = x0 / z0
+ // yn = y0 / z0
+ const float M00 = matrix[0];
+ const float M10 = matrix[1];
+ const float M20 = matrix[2];
+ const float M01 = matrix[0 + 1 * 3];
+ const float M11 = matrix[1 + 1 * 3];
+ const float M21 = matrix[2 + 1 * 3];
+ const float M02 = matrix[0 + 2 * 3];
+ const float M12 = matrix[1 + 2 * 3];
+ const float M22 = matrix[2 + 2 * 3];
+
+ const int width = in.shape().x();
+ const int height = in.shape().y();
+
+ for(int element_idx = 0; element_idx < in.num_elements(); ++element_idx)
+ {
+ valid_mask[element_idx] = 1;
+ Coordinates id = index2coord(in.shape(), element_idx);
+ int idx = id.x();
+ int idy = id.y();
+ const float z0 = M20 * idx + M21 * idy + M22;
+
+ float x0 = (M00 * idx + M01 * idy + M02);
+ float y0 = (M10 * idx + M11 * idy + M12);
+
+ float xn = x0 / z0;
+ float yn = y0 / z0;
+ id.set(0, static_cast<int>(std::floor(xn)));
+ id.set(1, static_cast<int>(std::floor(yn)));
+ if((0 <= yn) && (yn < height) && (0 <= xn) && (xn < width))
+ {
+ switch(policy)
+ {
+ case InterpolationPolicy::NEAREST_NEIGHBOR:
+ out[element_idx] = tensor_elem_at(in, id, border_mode, constant_border_value);
+ break;
+ case InterpolationPolicy::BILINEAR:
+ (valid_bilinear_policy(xn, yn, width, height, border_mode)) ? out[element_idx] = bilinear_policy(in, id, xn, yn, border_mode, constant_border_value) : valid_mask[element_idx] = 0;
+ break;
+ case InterpolationPolicy::AREA:
+ default:
+ ARM_COMPUTE_ERROR("Interpolation not supported");
+ }
+ }
+ else
+ {
+ if(border_mode == BorderMode::UNDEFINED)
+ {
+ valid_mask[element_idx] = 0;
+ }
+ else
+ {
+ switch(policy)
+ {
+ case InterpolationPolicy::NEAREST_NEIGHBOR:
+ if(border_mode == BorderMode::CONSTANT)
+ {
+ out[element_idx] = constant_border_value;
+ }
+ else if(border_mode == BorderMode::REPLICATE)
+ {
+ id.set(0, std::max(0, std::min(static_cast<int>(xn), width - 1)));
+ id.set(1, std::max(0, std::min(static_cast<int>(yn), height - 1)));
+ out[element_idx] = in[coord2index(in.shape(), id)];
+ }
+ break;
+ case InterpolationPolicy::BILINEAR:
+ out[element_idx] = bilinear_policy(in, id, xn, yn, border_mode, constant_border_value);
+ break;
+ case InterpolationPolicy::AREA:
+ default:
+ ARM_COMPUTE_ERROR("Interpolation not supported");
+ }
+ }
+ }
+ }
+}
+
// Batch Normalization Layer for fixed point type
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type * = nullptr>
void batch_normalization_layer(const Tensor<T> &in, Tensor<T> &out, const Tensor<T> &mean, const Tensor<T> &var, const Tensor<T> &beta, const Tensor<T> &gamma, float epsilon, int fixed_point_position)