diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/ROIAlignLayer.cpp | 35 | ||||
-rw-r--r-- | tests/validation/reference/ROIAlignLayer.h | 2 |
2 files changed, 21 insertions, 16 deletions
diff --git a/tests/validation/reference/ROIAlignLayer.cpp b/tests/validation/reference/ROIAlignLayer.cpp index 68a465d18f..8a76983d44 100644 --- a/tests/validation/reference/ROIAlignLayer.cpp +++ b/tests/validation/reference/ROIAlignLayer.cpp @@ -114,30 +114,35 @@ T clamp(T value, T lower, T upper) } } // namespace template <typename T> -SimpleTensor<T> roi_align_layer(const SimpleTensor<T> &src, const std::vector<ROI> &rois, const ROIPoolingLayerInfo &pool_info) +SimpleTensor<T> roi_align_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &rois, const ROIPoolingLayerInfo &pool_info) { - const size_t num_rois = rois.size(); - DataType dst_data_type = src.data_type(); + const size_t values_per_roi = rois.shape()[0]; + const size_t num_rois = rois.shape()[1]; + DataType dst_data_type = src.data_type(); + + const auto *rois_ptr = static_cast<const T *>(rois.data()); TensorShape input_shape = src.shape(); TensorShape output_shape(pool_info.pooled_width(), pool_info.pooled_height(), src.shape()[2], num_rois); SimpleTensor<T> dst(output_shape, dst_data_type); // Iterate over every pixel of the input image - for(size_t px = 0; px < pool_info.pooled_width(); px++) + for(size_t px = 0; px < pool_info.pooled_width(); ++px) { - for(size_t py = 0; py < pool_info.pooled_height(); py++) + for(size_t py = 0; py < pool_info.pooled_height(); ++py) { - for(size_t pw = 0; pw < num_rois; pw++) + for(size_t pw = 0; pw < num_rois; ++pw) { - ROI roi = rois[pw]; - const int roi_batch = roi.batch_idx; + const unsigned int roi_batch = rois_ptr[values_per_roi * pw]; + const auto x1 = float(rois_ptr[values_per_roi * pw + 1]); + const auto y1 = float(rois_ptr[values_per_roi * pw + 2]); + const auto x2 = float(rois_ptr[values_per_roi * pw + 3]); + const auto y2 = float(rois_ptr[values_per_roi * pw + 4]); - const float roi_anchor_x = roi.rect.x * pool_info.spatial_scale(); - const float roi_anchor_y = roi.rect.y * pool_info.spatial_scale(); - const float roi_dims_x = std::max(roi.rect.width * pool_info.spatial_scale(), 1.0f); - const float roi_dims_y = std::max(roi.rect.height * pool_info.spatial_scale(), 1.0f); - ; + const float roi_anchor_x = x1 * pool_info.spatial_scale(); + const float roi_anchor_y = y1 * pool_info.spatial_scale(); + const float roi_dims_x = std::max((x2 - x1) * pool_info.spatial_scale(), 1.0f); + const float roi_dims_y = std::max((y2 - y1) * pool_info.spatial_scale(), 1.0f); float bin_size_x = roi_dims_x / pool_info.pooled_width(); float bin_size_y = roi_dims_y / pool_info.pooled_height(); @@ -178,8 +183,8 @@ SimpleTensor<T> roi_align_layer(const SimpleTensor<T> &src, const std::vector<RO } return dst; } -template SimpleTensor<float> roi_align_layer(const SimpleTensor<float> &src, const std::vector<ROI> &rois, const ROIPoolingLayerInfo &pool_info); -template SimpleTensor<half> roi_align_layer(const SimpleTensor<half> &src, const std::vector<ROI> &rois, const ROIPoolingLayerInfo &pool_info); +template SimpleTensor<float> roi_align_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &rois, const ROIPoolingLayerInfo &pool_info); +template SimpleTensor<half> roi_align_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &rois, const ROIPoolingLayerInfo &pool_info); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/ROIAlignLayer.h b/tests/validation/reference/ROIAlignLayer.h index 818f9b147c..b67ff42166 100644 --- a/tests/validation/reference/ROIAlignLayer.h +++ b/tests/validation/reference/ROIAlignLayer.h @@ -37,7 +37,7 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> roi_align_layer(const SimpleTensor<T> &src, const std::vector<ROI> &rois, const ROIPoolingLayerInfo &pool_info); +SimpleTensor<T> roi_align_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &rois, const ROIPoolingLayerInfo &pool_info); } // namespace reference } // namespace validation } // namespace test |