aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/ROIAlignLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/ROIAlignLayer.cpp')
-rw-r--r--tests/validation/CL/ROIAlignLayer.cpp36
1 files changed, 19 insertions, 17 deletions
diff --git a/tests/validation/CL/ROIAlignLayer.cpp b/tests/validation/CL/ROIAlignLayer.cpp
index 926a3de68d..566e1985b3 100644
--- a/tests/validation/CL/ROIAlignLayer.cpp
+++ b/tests/validation/CL/ROIAlignLayer.cpp
@@ -58,26 +58,26 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/rois
TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/output
TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching depth size input/output
- TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching number of rois and output batch size
+ TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching number of rois and output batch size
TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Invalid number of values per ROIS
- TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching height and width input/output
+ TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching height and width input/output
}),
- framework::dataset::make("RoisInfo", { TensorInfo(TensorShape(5, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(5, 3U), 1, DataType::F16),
- TensorInfo(TensorShape(5, 3U), 1, DataType::F32),
+ framework::dataset::make("RoisInfo", { TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(5, 4U), 1, DataType::F16),
+ TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
TensorInfo(TensorShape(5, 10U), 1, DataType::F32),
- TensorInfo(TensorShape(4, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(4, 4U), 1, DataType::F32),
TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F16),
- TensorInfo(TensorShape(7U, 7U, 4U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(7U, 7U, 2U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(5U, 5U, 2U, 4U), 1, DataType::F32),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F16),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(5U, 5U, 3U, 4U), 1, DataType::F32),
})),
framework::dataset::make("PoolInfo", { ROIPoolingLayerInfo(7U, 7U, 1./8),
ROIPoolingLayerInfo(7U, 7U, 1./8),
@@ -100,15 +100,17 @@ using CLROIAlignLayerFixture = ROIAlignLayerFixture<CLTensor, CLAccessor, CLROIA
TEST_SUITE(Float)
FIXTURE_DATA_TEST_CASE(SmallROIAlignLayerFloat, CLROIAlignLayerFixture<float>, framework::DatasetMode::ALL,
- framework::dataset::combine(datasets::SmallROIDataset(),
- framework::dataset::make("DataType", { DataType::F32 })))
+ framework::dataset::combine(framework::dataset::combine(datasets::SmallROIDataset(),
+ framework::dataset::make("DataType", { DataType::F32 })),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
validate(CLAccessor(_target), _reference, relative_tolerance_f32, .02f, absolute_tolerance_f32);
}
FIXTURE_DATA_TEST_CASE(SmallROIAlignLayerHalf, CLROIAlignLayerFixture<half>, framework::DatasetMode::ALL,
- framework::dataset::combine(datasets::SmallROIDataset(),
- framework::dataset::make("DataType", { DataType::F16 })))
+ framework::dataset::combine(framework::dataset::combine(datasets::SmallROIDataset(),
+ framework::dataset::make("DataType", { DataType::F16 })),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
validate(CLAccessor(_target), _reference, relative_tolerance_f16, .02f, absolute_tolerance_f16);