diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2019-08-23 11:49:04 +0100 |
---|---|---|
committer | Michele Di Giorgio <michele.digiorgio@arm.com> | 2019-08-29 09:17:46 +0000 |
commit | 578a9fc6c06ebbd6e2650372029e339a4cbcacca (patch) | |
tree | 7aa428d18323cce34dabe088063d2cf533c36fac /tests/validation/CL/ROIAlignLayer.cpp | |
parent | c7ec194d5ae5ee5c9af3e6aa3de43e30382d8f87 (diff) | |
download | ComputeLibrary-578a9fc6c06ebbd6e2650372029e339a4cbcacca.tar.gz |
COMPMID-2317: Implement CLROIAlignLayer
Change-Id: Iaa61b7a3528d3f82339d2ff8a2d77e77a1c68603
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1821
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL/ROIAlignLayer.cpp')
-rw-r--r-- | tests/validation/CL/ROIAlignLayer.cpp | 70 |
1 files changed, 50 insertions, 20 deletions
diff --git a/tests/validation/CL/ROIAlignLayer.cpp b/tests/validation/CL/ROIAlignLayer.cpp index 566e1985b3..b213c6815f 100644 --- a/tests/validation/CL/ROIAlignLayer.cpp +++ b/tests/validation/CL/ROIAlignLayer.cpp @@ -41,11 +41,13 @@ namespace validation { namespace { -RelativeTolerance<float> relative_tolerance_f32(0.01f); -AbsoluteTolerance<float> absolute_tolerance_f32(0.001f); +constexpr RelativeTolerance<float> relative_tolerance_f32(0.01f); +constexpr AbsoluteTolerance<float> absolute_tolerance_f32(0.001f); -RelativeTolerance<float> relative_tolerance_f16(0.01f); -AbsoluteTolerance<float> absolute_tolerance_f16(0.001f); +constexpr RelativeTolerance<float> relative_tolerance_f16(0.01f); +constexpr AbsoluteTolerance<float> absolute_tolerance_f16(0.001f); + +constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); } // namespace TEST_SUITE(CL) @@ -55,13 +57,14 @@ TEST_SUITE(RoiAlign) // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( framework::dataset::make("InputInfo", { TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), - TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/rois - TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/output - TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching depth size input/output - TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching number of rois and output batch size - TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Invalid number of values per ROIS - TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching height and width input/output - + TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/rois + TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/output + TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching depth size input/output + TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching number of rois and output batch size + TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Invalid number of values per ROIS + TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching height and width input/output + TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255.f, 127)), // Invalid ROIS data type + TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255.f, 127)), // Invalid ROIS Quantization Info }), framework::dataset::make("RoisInfo", { TensorInfo(TensorShape(5, 4U), 1, DataType::F32), TensorInfo(TensorShape(5, 4U), 1, DataType::F16), @@ -70,6 +73,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( TensorInfo(TensorShape(5, 10U), 1, DataType::F32), TensorInfo(TensorShape(4, 4U), 1, DataType::F32), TensorInfo(TensorShape(5, 4U), 1, DataType::F32), + TensorInfo(TensorShape(5, 4U), 1, DataType::F32), + TensorInfo(TensorShape(5, 4U), 1, DataType::QASYMM16, QuantizationInfo(0.2f, 0)), })), framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32), TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32), @@ -78,6 +83,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32), TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32), TensorInfo(TensorShape(5U, 5U, 3U, 4U), 1, DataType::F32), + TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255.f, 120)), + TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255.f, 120)), })), framework::dataset::make("PoolInfo", { ROIPoolingLayerInfo(7U, 7U, 1./8), ROIPoolingLayerInfo(7U, 7U, 1./8), @@ -86,8 +93,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( ROIPoolingLayerInfo(7U, 7U, 1./8), ROIPoolingLayerInfo(7U, 7U, 1./8), ROIPoolingLayerInfo(7U, 7U, 1./8), + ROIPoolingLayerInfo(7U, 7U, 1./8), })), - framework::dataset::make("Expected", { true, false, false, false, false, false, false })), + framework::dataset::make("Expected", { true, false, false, false, false, false, false, false, false })), input_info, rois_info, output_info, pool_info, expected) { ARM_COMPUTE_EXPECT(bool(CLROIAlignLayer::validate(&input_info.clone()->set_is_resizable(true), &rois_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), pool_info)) == expected, framework::LogLevel::ERRORS); @@ -99,24 +107,46 @@ template <typename T> using CLROIAlignLayerFixture = ROIAlignLayerFixture<CLTensor, CLAccessor, CLROIAlignLayer, T>; TEST_SUITE(Float) -FIXTURE_DATA_TEST_CASE(SmallROIAlignLayerFloat, CLROIAlignLayerFixture<float>, framework::DatasetMode::ALL, - framework::dataset::combine(framework::dataset::combine(datasets::SmallROIDataset(), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }))) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(Small, CLROIAlignLayerFixture<float>, framework::DatasetMode::ALL, + combine(combine(datasets::SmallROIDataset(), + framework::dataset::make("DataType", { DataType::F32 })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }))) { // Validate output validate(CLAccessor(_target), _reference, relative_tolerance_f32, .02f, absolute_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(SmallROIAlignLayerHalf, CLROIAlignLayerFixture<half>, framework::DatasetMode::ALL, - framework::dataset::combine(framework::dataset::combine(datasets::SmallROIDataset(), - framework::dataset::make("DataType", { DataType::F16 })), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }))) +TEST_SUITE_END() // FP32 +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(Small, CLROIAlignLayerFixture<half>, framework::DatasetMode::ALL, + combine(combine(datasets::SmallROIDataset(), + framework::dataset::make("DataType", { DataType::F16 })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }))) { // Validate output validate(CLAccessor(_target), _reference, relative_tolerance_f16, .02f, absolute_tolerance_f16); } +TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float +template <typename T> +using CLROIAlignLayerQuantizedFixture = ROIAlignLayerQuantizedFixture<CLTensor, CLAccessor, CLROIAlignLayer, T>; + +TEST_SUITE(Quantized) +TEST_SUITE(QASYMM8) +FIXTURE_DATA_TEST_CASE(Small, CLROIAlignLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, + combine(combine(combine(combine(datasets::SmallROIDataset(), + framework::dataset::make("DataType", { DataType::QASYMM8 })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + framework::dataset::make("InputQuantizationInfo", { QuantizationInfo(1.f / 255.f, 127) })), + framework::dataset::make("OutputQuantizationInfo", { QuantizationInfo(2.f / 255.f, 120) }))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_qasymm8); +} +TEST_SUITE_END() // QASYMM8 +TEST_SUITE_END() // Quantized + TEST_SUITE_END() // RoiAlign TEST_SUITE_END() // CL } // namespace validation |