aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ROIAlignLayerFixture.h
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2018-10-24 17:27:02 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2018-11-15 10:13:15 +0000
commit60f0a41c45813fa9c85cd4f8fbed57c4c9284a5c (patch)
treec3bda2f1f34a4a602875ddbe9b814b50365db192 /tests/validation/fixtures/ROIAlignLayerFixture.h
parent0cc37c31a36e7b146cf9640ad69925d7c06b71b4 (diff)
downloadComputeLibrary-60f0a41c45813fa9c85cd4f8fbed57c4c9284a5c.tar.gz
COMPMID-1676: Change CLROIAlign interface to accept ROIs as tensors
Change-Id: I69e995973597ba3927d29e4f6ed5438560e53d77
Diffstat (limited to 'tests/validation/fixtures/ROIAlignLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ROIAlignLayerFixture.h80
1 files changed, 60 insertions, 20 deletions
diff --git a/tests/validation/fixtures/ROIAlignLayerFixture.h b/tests/validation/fixtures/ROIAlignLayerFixture.h
index d327b0914e..c029fbae8a 100644
--- a/tests/validation/fixtures/ROIAlignLayerFixture.h
+++ b/tests/validation/fixtures/ROIAlignLayerFixture.h
@@ -41,18 +41,15 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename Array_T, typename ArrayAccessor, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class ROIAlignLayerFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, unsigned int num_rois, DataType data_type, int batches)
+ void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type)
{
- input_shape.set(2, batches);
- std::vector<ROI> rois = generate_random_rois(input_shape, pool_info, num_rois, 0U);
-
- _target = compute_target(input_shape, data_type, rois, pool_info);
- _reference = compute_reference(input_shape, data_type, rois, pool_info);
+ _target = compute_target(input_shape, data_type, pool_info, rois_shape);
+ _reference = compute_reference(input_shape, data_type, pool_info, rois_shape);
}
protected:
@@ -62,37 +59,78 @@ protected:
library->fill_tensor_uniform(tensor, 0);
}
+ template <typename U>
+ void generate_rois(U &&rois, const TensorShape &shape, const ROIPoolingLayerInfo &pool_info, TensorShape rois_shape)
+ {
+ const size_t values_per_roi = rois_shape.x();
+ const size_t num_rois = rois_shape.y();
+
+ std::mt19937 gen(library->seed());
+ T *rois_ptr = static_cast<T *>(rois.data());
+
+ const float pool_width = pool_info.pooled_width();
+ const float pool_height = pool_info.pooled_height();
+ const float roi_scale = pool_info.spatial_scale();
+
+ // Calculate distribution bounds
+ const auto scaled_width = static_cast<T>((shape.x() / roi_scale) / pool_width);
+ const auto scaled_height = static_cast<T>((shape.y() / roi_scale) / pool_height);
+ const auto min_width = static_cast<T>(pool_width / roi_scale);
+ const auto min_height = static_cast<T>(pool_height / roi_scale);
+
+ // Create distributions
+ std::uniform_int_distribution<int> dist_batch(0, shape[3] - 1);
+ std::uniform_int_distribution<> dist_x1(0, scaled_width);
+ std::uniform_int_distribution<> dist_y1(0, scaled_height);
+ std::uniform_int_distribution<> dist_w(min_width, std::max(float(min_width), (pool_width - 2) * scaled_width));
+ std::uniform_int_distribution<> dist_h(min_height, std::max(float(min_height), (pool_height - 2) * scaled_height));
+
+ for(unsigned int pw = 0; pw < num_rois; ++pw)
+ {
+ const auto batch_idx = dist_batch(gen);
+ const auto x1 = dist_x1(gen);
+ const auto y1 = dist_y1(gen);
+ const auto x2 = x1 + dist_w(gen);
+ const auto y2 = y1 + dist_h(gen);
+
+ rois_ptr[values_per_roi * pw] = batch_idx;
+ rois_ptr[values_per_roi * pw + 1] = x1;
+ rois_ptr[values_per_roi * pw + 2] = y1;
+ rois_ptr[values_per_roi * pw + 3] = x2;
+ rois_ptr[values_per_roi * pw + 4] = y2;
+ }
+ }
+
TensorType compute_target(const TensorShape &input_shape,
DataType data_type,
- std::vector<ROI> const &rois,
- const ROIPoolingLayerInfo &pool_info)
+ const ROIPoolingLayerInfo &pool_info,
+ const TensorShape rois_shape)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, data_type);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type);
+ TensorType rois_tensor = create_tensor<TensorType>(rois_shape, data_type);
TensorType dst;
- size_t num_rois = rois.size();
-
- // Create roi arrays
- std::unique_ptr<Array_T> rois_array = arm_compute::support::cpp14::make_unique<Array_T>(num_rois);
- fill_array(ArrayAccessor(*rois_array), rois);
-
// Create and configure function
FunctionType roi_align_layer;
- roi_align_layer.configure(&src, rois_array.get(), &dst, pool_info);
+ roi_align_layer.configure(&src, &rois_tensor, &dst, pool_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(rois_tensor.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
// Allocate tensors
src.allocator()->allocate();
+ rois_tensor.allocator()->allocate();
dst.allocator()->allocate();
ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!rois_tensor.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
fill(AccessorType(src));
+ generate_rois(AccessorType(rois_tensor), input_shape, pool_info, rois_shape);
// Compute function
roi_align_layer.run();
@@ -102,16 +140,18 @@ protected:
SimpleTensor<T> compute_reference(const TensorShape &input_shape,
DataType data_type,
- std::vector<ROI> const &rois,
- const ROIPoolingLayerInfo &pool_info)
+ const ROIPoolingLayerInfo &pool_info,
+ const TensorShape rois_shape)
{
// Create reference tensor
SimpleTensor<T> src{ input_shape, data_type };
+ SimpleTensor<T> rois_tensor{ rois_shape, data_type };
// Fill reference tensor
fill(src);
+ generate_rois(rois_tensor, input_shape, pool_info, rois_shape);
- return reference::roi_align_layer(src, rois, pool_info);
+ return reference::roi_align_layer(src, rois_tensor, pool_info);
}
TensorType _target{};