aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGeorge Wort <george.wort@arm.com>2019-01-08 11:41:54 +0000
committerGeorge Wort <george.wort@arm.com>2019-01-16 15:55:43 +0000
commit44b4e974590f1a6a07b235f203006cc9010b37e8 (patch)
tree1f7f76712847e9b7269bc56f972006dd9902ea3c /tests
parentf63885bc445af1329e6a5c44d94b5c5d78146b2c (diff)
downloadComputeLibrary-44b4e974590f1a6a07b235f203006cc9010b37e8.tar.gz
COMPMID-1794: Add support for NHWC in CLROIAlignLayer
Change-Id: If1df8f6c0549c986e607cbceb0977c80b2891b75 Reviewed-on: https://review.mlplatform.org/493 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/ROIAlignLayer.cpp36
-rw-r--r--tests/validation/fixtures/ROIAlignLayerFixture.h24
2 files changed, 34 insertions, 26 deletions
diff --git a/tests/validation/CL/ROIAlignLayer.cpp b/tests/validation/CL/ROIAlignLayer.cpp
index 926a3de68d..566e1985b3 100644
--- a/tests/validation/CL/ROIAlignLayer.cpp
+++ b/tests/validation/CL/ROIAlignLayer.cpp
@@ -58,26 +58,26 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/rois
TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching data type input/output
TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching depth size input/output
- TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching number of rois and output batch size
+ TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching number of rois and output batch size
TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Invalid number of values per ROIS
- TensorInfo(TensorShape(250U, 128U, 2U), 1, DataType::F32), // Mismatching height and width input/output
+ TensorInfo(TensorShape(250U, 128U, 3U), 1, DataType::F32), // Mismatching height and width input/output
}),
- framework::dataset::make("RoisInfo", { TensorInfo(TensorShape(5, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(5, 3U), 1, DataType::F16),
- TensorInfo(TensorShape(5, 3U), 1, DataType::F32),
+ framework::dataset::make("RoisInfo", { TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(5, 4U), 1, DataType::F16),
+ TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
TensorInfo(TensorShape(5, 10U), 1, DataType::F32),
- TensorInfo(TensorShape(4, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(4, 4U), 1, DataType::F32),
TensorInfo(TensorShape(5, 4U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F16),
- TensorInfo(TensorShape(7U, 7U, 4U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(7U, 7U, 2U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(7U, 7U, 3U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(5U, 5U, 2U, 4U), 1, DataType::F32),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F16),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(7U, 7U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(5U, 5U, 3U, 4U), 1, DataType::F32),
})),
framework::dataset::make("PoolInfo", { ROIPoolingLayerInfo(7U, 7U, 1./8),
ROIPoolingLayerInfo(7U, 7U, 1./8),
@@ -100,15 +100,17 @@ using CLROIAlignLayerFixture = ROIAlignLayerFixture<CLTensor, CLAccessor, CLROIA
TEST_SUITE(Float)
FIXTURE_DATA_TEST_CASE(SmallROIAlignLayerFloat, CLROIAlignLayerFixture<float>, framework::DatasetMode::ALL,
- framework::dataset::combine(datasets::SmallROIDataset(),
- framework::dataset::make("DataType", { DataType::F32 })))
+ framework::dataset::combine(framework::dataset::combine(datasets::SmallROIDataset(),
+ framework::dataset::make("DataType", { DataType::F32 })),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
validate(CLAccessor(_target), _reference, relative_tolerance_f32, .02f, absolute_tolerance_f32);
}
FIXTURE_DATA_TEST_CASE(SmallROIAlignLayerHalf, CLROIAlignLayerFixture<half>, framework::DatasetMode::ALL,
- framework::dataset::combine(datasets::SmallROIDataset(),
- framework::dataset::make("DataType", { DataType::F16 })))
+ framework::dataset::combine(framework::dataset::combine(datasets::SmallROIDataset(),
+ framework::dataset::make("DataType", { DataType::F16 })),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
validate(CLAccessor(_target), _reference, relative_tolerance_f16, .02f, absolute_tolerance_f16);
diff --git a/tests/validation/fixtures/ROIAlignLayerFixture.h b/tests/validation/fixtures/ROIAlignLayerFixture.h
index c029fbae8a..dfbb478a41 100644
--- a/tests/validation/fixtures/ROIAlignLayerFixture.h
+++ b/tests/validation/fixtures/ROIAlignLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,9 +46,9 @@ class ROIAlignLayerFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type)
+ void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type, DataLayout data_layout)
{
- _target = compute_target(input_shape, data_type, pool_info, rois_shape);
+ _target = compute_target(input_shape, data_type, data_layout, pool_info, rois_shape);
_reference = compute_reference(input_shape, data_type, pool_info, rois_shape);
}
@@ -60,7 +60,7 @@ protected:
}
template <typename U>
- void generate_rois(U &&rois, const TensorShape &shape, const ROIPoolingLayerInfo &pool_info, TensorShape rois_shape)
+ void generate_rois(U &&rois, const TensorShape &shape, const ROIPoolingLayerInfo &pool_info, TensorShape rois_shape, DataLayout data_layout = DataLayout::NCHW)
{
const size_t values_per_roi = rois_shape.x();
const size_t num_rois = rois_shape.y();
@@ -73,8 +73,8 @@ protected:
const float roi_scale = pool_info.spatial_scale();
// Calculate distribution bounds
- const auto scaled_width = static_cast<T>((shape.x() / roi_scale) / pool_width);
- const auto scaled_height = static_cast<T>((shape.y() / roi_scale) / pool_height);
+ const auto scaled_width = static_cast<T>((shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)] / roi_scale) / pool_width);
+ const auto scaled_height = static_cast<T>((shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)] / roi_scale) / pool_height);
const auto min_width = static_cast<T>(pool_width / roi_scale);
const auto min_height = static_cast<T>(pool_height / roi_scale);
@@ -101,13 +101,19 @@ protected:
}
}
- TensorType compute_target(const TensorShape &input_shape,
+ TensorType compute_target(TensorShape input_shape,
DataType data_type,
+ DataLayout data_layout,
const ROIPoolingLayerInfo &pool_info,
const TensorShape rois_shape)
{
+ if(data_layout == DataLayout::NHWC)
+ {
+ permute(input_shape, PermutationVector(2U, 0U, 1U));
+ }
+
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, data_type);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, QuantizationInfo(), data_layout);
TensorType rois_tensor = create_tensor<TensorType>(rois_shape, data_type);
TensorType dst;
@@ -130,7 +136,7 @@ protected:
// Fill tensors
fill(AccessorType(src));
- generate_rois(AccessorType(rois_tensor), input_shape, pool_info, rois_shape);
+ generate_rois(AccessorType(rois_tensor), input_shape, pool_info, rois_shape, data_layout);
// Compute function
roi_align_layer.run();