aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ROIAlignLayer.cpp
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2018-10-24 17:27:02 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2018-11-15 10:13:15 +0000
commit60f0a41c45813fa9c85cd4f8fbed57c4c9284a5c (patch)
treec3bda2f1f34a4a602875ddbe9b814b50365db192 /tests/validation/reference/ROIAlignLayer.cpp
parent0cc37c31a36e7b146cf9640ad69925d7c06b71b4 (diff)
downloadComputeLibrary-60f0a41c45813fa9c85cd4f8fbed57c4c9284a5c.tar.gz
COMPMID-1676: Change CLROIAlign interface to accept ROIs as tensors
Change-Id: I69e995973597ba3927d29e4f6ed5438560e53d77
Diffstat (limited to 'tests/validation/reference/ROIAlignLayer.cpp')
-rw-r--r--tests/validation/reference/ROIAlignLayer.cpp35
1 files changed, 20 insertions, 15 deletions
diff --git a/tests/validation/reference/ROIAlignLayer.cpp b/tests/validation/reference/ROIAlignLayer.cpp
index 68a465d18f..8a76983d44 100644
--- a/tests/validation/reference/ROIAlignLayer.cpp
+++ b/tests/validation/reference/ROIAlignLayer.cpp
@@ -114,30 +114,35 @@ T clamp(T value, T lower, T upper)
}
} // namespace
template <typename T>
-SimpleTensor<T> roi_align_layer(const SimpleTensor<T> &src, const std::vector<ROI> &rois, const ROIPoolingLayerInfo &pool_info)
+SimpleTensor<T> roi_align_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &rois, const ROIPoolingLayerInfo &pool_info)
{
- const size_t num_rois = rois.size();
- DataType dst_data_type = src.data_type();
+ const size_t values_per_roi = rois.shape()[0];
+ const size_t num_rois = rois.shape()[1];
+ DataType dst_data_type = src.data_type();
+
+ const auto *rois_ptr = static_cast<const T *>(rois.data());
TensorShape input_shape = src.shape();
TensorShape output_shape(pool_info.pooled_width(), pool_info.pooled_height(), src.shape()[2], num_rois);
SimpleTensor<T> dst(output_shape, dst_data_type);
// Iterate over every pixel of the input image
- for(size_t px = 0; px < pool_info.pooled_width(); px++)
+ for(size_t px = 0; px < pool_info.pooled_width(); ++px)
{
- for(size_t py = 0; py < pool_info.pooled_height(); py++)
+ for(size_t py = 0; py < pool_info.pooled_height(); ++py)
{
- for(size_t pw = 0; pw < num_rois; pw++)
+ for(size_t pw = 0; pw < num_rois; ++pw)
{
- ROI roi = rois[pw];
- const int roi_batch = roi.batch_idx;
+ const unsigned int roi_batch = rois_ptr[values_per_roi * pw];
+ const auto x1 = float(rois_ptr[values_per_roi * pw + 1]);
+ const auto y1 = float(rois_ptr[values_per_roi * pw + 2]);
+ const auto x2 = float(rois_ptr[values_per_roi * pw + 3]);
+ const auto y2 = float(rois_ptr[values_per_roi * pw + 4]);
- const float roi_anchor_x = roi.rect.x * pool_info.spatial_scale();
- const float roi_anchor_y = roi.rect.y * pool_info.spatial_scale();
- const float roi_dims_x = std::max(roi.rect.width * pool_info.spatial_scale(), 1.0f);
- const float roi_dims_y = std::max(roi.rect.height * pool_info.spatial_scale(), 1.0f);
- ;
+ const float roi_anchor_x = x1 * pool_info.spatial_scale();
+ const float roi_anchor_y = y1 * pool_info.spatial_scale();
+ const float roi_dims_x = std::max((x2 - x1) * pool_info.spatial_scale(), 1.0f);
+ const float roi_dims_y = std::max((y2 - y1) * pool_info.spatial_scale(), 1.0f);
float bin_size_x = roi_dims_x / pool_info.pooled_width();
float bin_size_y = roi_dims_y / pool_info.pooled_height();
@@ -178,8 +183,8 @@ SimpleTensor<T> roi_align_layer(const SimpleTensor<T> &src, const std::vector<RO
}
return dst;
}
-template SimpleTensor<float> roi_align_layer(const SimpleTensor<float> &src, const std::vector<ROI> &rois, const ROIPoolingLayerInfo &pool_info);
-template SimpleTensor<half> roi_align_layer(const SimpleTensor<half> &src, const std::vector<ROI> &rois, const ROIPoolingLayerInfo &pool_info);
+template SimpleTensor<float> roi_align_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &rois, const ROIPoolingLayerInfo &pool_info);
+template SimpleTensor<half> roi_align_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &rois, const ROIPoolingLayerInfo &pool_info);
} // namespace reference
} // namespace validation
} // namespace test