aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/ROIAlignLayer.cpp
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2018-10-24 17:27:02 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2018-11-15 10:13:15 +0000
commit60f0a41c45813fa9c85cd4f8fbed57c4c9284a5c (patch)
treec3bda2f1f34a4a602875ddbe9b814b50365db192 /tests/validation/CL/ROIAlignLayer.cpp
parent0cc37c31a36e7b146cf9640ad69925d7c06b71b4 (diff)
downloadComputeLibrary-60f0a41c45813fa9c85cd4f8fbed57c4c9284a5c.tar.gz
COMPMID-1676: Change CLROIAlign interface to accept ROIs as tensors
Change-Id: I69e995973597ba3927d29e4f6ed5438560e53d77
Diffstat (limited to 'tests/validation/CL/ROIAlignLayer.cpp')
-rw-r--r--tests/validation/CL/ROIAlignLayer.cpp48
1 files changed, 33 insertions, 15 deletions
diff --git a/tests/validation/CL/ROIAlignLayer.cpp b/tests/validation/CL/ROIAlignLayer.cpp
index acea6d447c..f3fc3818f2 100644
--- a/tests/validation/CL/ROIAlignLayer.cpp
+++ b/tests/validation/CL/ROIAlignLayer.cpp
@@ -24,9 +24,8 @@
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/functions/CLROIAlignLayer.h"
#include "tests/CL/CLAccessor.h"
-#include "tests/CL/CLArrayAccessor.h"
#include "tests/Globals.h"
-#include "tests/datasets/ROIPoolingLayerDataset.h"
+#include "tests/datasets/ROIAlignLayerDataset.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
@@ -43,7 +42,10 @@ namespace validation
namespace
{
RelativeTolerance<float> relative_tolerance_f32(0.01f);
-RelativeTolerance<float> absolute_tolerance_f32(0.001f);
+AbsoluteTolerance<float> absolute_tolerance_f32(0.001f);
+
+RelativeTolerance<float> relative_tolerance_f16(0.01f);
+AbsoluteTolerance<float> absolute_tolerance_f16(0.001f);
} // namespace
TEST_SUITE(CL)
@@ -53,17 +55,28 @@ TEST_SUITE(RoiAlign)
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
framework::dataset::make("InputInfo", { TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/rois
TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/output
TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching depth size input/output
TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching number of rois and output batch size
+ TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Invalid number of values per ROIS
TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching height and width input/output
}),
- framework::dataset::make("NumRois", { 3U, 3U, 4U, 10U, 4U})),
+ framework::dataset::make("RoisInfo", { TensorInfo(TensorShape(5, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(5, 3U), 1, DataType::F16),
+ TensorInfo(TensorShape(5, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(5, 10U), 1, DataType::F32),
+ TensorInfo(TensorShape(4, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
+ })),
framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F16),
TensorInfo(TensorShape(7U, 7U, 4U, 3U), 1, DataType::F32),
TensorInfo(TensorShape(7U, 7U, 2U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
TensorInfo(TensorShape(5U, 5U, 2U, 4U), 1, DataType::F32),
})),
framework::dataset::make("PoolInfo", { ROIPoolingLayerInfo(7U, 7U, 1./8),
@@ -71,30 +84,35 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
ROIPoolingLayerInfo(7U, 7U, 1./8),
ROIPoolingLayerInfo(7U, 7U, 1./8),
ROIPoolingLayerInfo(7U, 7U, 1./8),
+ ROIPoolingLayerInfo(7U, 7U, 1./8),
+ ROIPoolingLayerInfo(7U, 7U, 1./8),
})),
- framework::dataset::make("Expected", { true, false, false, false, false })),
- input_info, num_rois, output_info, pool_info, expected)
+ framework::dataset::make("Expected", { true, false, false, false, false, false, false })),
+ input_info, rois_info, output_info, pool_info, expected)
{
- ARM_COMPUTE_EXPECT(bool(CLROIAlignLayer::validate(&input_info.clone()->set_is_resizable(true), num_rois, &output_info.clone()->set_is_resizable(true), pool_info)) == expected, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(bool(CLROIAlignLayer::validate(&input_info.clone()->set_is_resizable(true), &rois_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), pool_info)) == expected, framework::LogLevel::ERRORS);
}
// clang-format on
// *INDENT-ON*
template <typename T>
-using CLROIAlignLayerFixture = ROIAlignLayerFixture<CLTensor, CLAccessor, CLROIAlignLayer, CLArray<ROI>, CLArrayAccessor<ROI>, T>;
+using CLROIAlignLayerFixture = ROIAlignLayerFixture<CLTensor, CLAccessor, CLROIAlignLayer, T>;
TEST_SUITE(Float)
-TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(SmallROIAlignLayer, CLROIAlignLayerFixture<float>, framework::DatasetMode::ALL,
- framework::dataset::combine(framework::dataset::combine(datasets::SmallROIPoolingLayerDataset(),
- framework::dataset::make("DataType", { DataType::F32 })),
- framework::dataset::make("Batches", { 1, 4, 8 })))
+FIXTURE_DATA_TEST_CASE(SmallROIAlignLayerFloat, CLROIAlignLayerFixture<float>, framework::DatasetMode::ALL,
+ framework::dataset::combine(datasets::SmallROIAlignLayerDataset(),
+ framework::dataset::make("DataType", { DataType::F32 })))
{
// Validate output
validate(CLAccessor(_target), _reference, relative_tolerance_f32, .02f, absolute_tolerance_f32);
}
-TEST_SUITE_END() // FP32
-
+FIXTURE_DATA_TEST_CASE(SmallROIAlignLayerHalf, CLROIAlignLayerFixture<half>, framework::DatasetMode::ALL,
+ framework::dataset::combine(datasets::SmallROIAlignLayerDataset(),
+ framework::dataset::make("DataType", { DataType::F16 })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, relative_tolerance_f16, .02f, absolute_tolerance_f16);
+}
TEST_SUITE_END() // Float
TEST_SUITE_END() // RoiAlign