aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/Remap.cpp69
-rw-r--r--tests/validation/fixtures/RemapFixture.h40
-rw-r--r--tests/validation/reference/Remap.cpp5
3 files changed, 101 insertions, 13 deletions
diff --git a/tests/validation/CL/Remap.cpp b/tests/validation/CL/Remap.cpp
index bbb3cecea9..7849d77394 100644
--- a/tests/validation/CL/Remap.cpp
+++ b/tests/validation/CL/Remap.cpp
@@ -48,15 +48,64 @@ constexpr AbsoluteTolerance<uint8_t> tolerance_value(1);
TEST_SUITE(CL)
TEST_SUITE(Remap)
+
+// *INDENT-OFF*
+// clang-format off
+
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(
+ framework::dataset::make("input", { TensorInfo(TensorShape(10U, 10U), 1, DataType::U8, DataLayout::NCHW),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::U8, DataLayout::NHWC),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F16, DataLayout::NCHW),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F16, DataLayout::NHWC)
+ }),
+ framework::dataset::make("map_x", { TensorInfo(TensorShape(10U, 10U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F32, DataLayout::NHWC)
+
+ })),
+ framework::dataset::make("map_y", { TensorInfo(TensorShape(10U, 10U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F32, DataLayout::NHWC)
+ })),
+ framework::dataset::make("output", { TensorInfo(TensorShape(10U, 10U), 1, DataType::U8, DataLayout::NCHW),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::U8, DataLayout::NHWC),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F16, DataLayout::NCHW),
+ TensorInfo(TensorShape(10U, 10U), 1, DataType::F16, DataLayout::NHWC)
+ })),
+ framework::dataset::make("policy",{ InterpolationPolicy::NEAREST_NEIGHBOR,
+ InterpolationPolicy::NEAREST_NEIGHBOR,
+ InterpolationPolicy::NEAREST_NEIGHBOR,
+ InterpolationPolicy::NEAREST_NEIGHBOR
+ })),
+ framework::dataset::make("border_mode",{ BorderMode::CONSTANT,
+ BorderMode::CONSTANT,
+ BorderMode::CONSTANT,
+ BorderMode::CONSTANT
+ })),
+ framework::dataset::make("Expected", { true, // NCHW, U8
+ true, // NHWC, U8
+ false, // NCHW, F16
+ true // NHWC, F16
+ })),
+ input, map_x, map_y, output, policy, border_mode, expected)
+{
+ ARM_COMPUTE_EXPECT(bool(CLRemap::validate(&input, &map_x, &map_y, &output, policy, border_mode, PixelValue{})) == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
template <typename T>
using CLRemapFixture = RemapValidationFixture<CLTensor, CLAccessor, CLRemap, T>;
template <typename T>
using CLRemapLayoutFixture = RemapValidationMixedLayoutFixture<CLTensor, CLAccessor, CLRemap, T>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CLRemapLayoutFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
- framework::dataset::make("DataType", DataType::U8)),
- framework::dataset::make("BorderModes", { BorderMode::UNDEFINED, BorderMode::CONSTANT })),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+TEST_SUITE(U8)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLRemapLayoutFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
+ framework::dataset::make("DataType", DataType::U8)),
+ framework::dataset::make("BorderModes", { BorderMode::UNDEFINED, BorderMode::CONSTANT })),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
validate(CLAccessor(_target), _reference, _valid_mask, tolerance_value);
@@ -69,7 +118,19 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLRemapFixture<uint8_t>, framework::DatasetMode
// Validate output
validate(CLAccessor(_target), _reference, _valid_mask, tolerance_value);
}
+TEST_SUITE_END() // U8
+TEST_SUITE(F16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLRemapLayoutFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("BorderModes", { BorderMode::UNDEFINED, BorderMode::CONSTANT })),
+ framework::dataset::make("DataLayout", DataLayout::NHWC)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, _valid_mask, tolerance_value);
+}
+TEST_SUITE_END() // F16
TEST_SUITE_END()
TEST_SUITE_END()
} // namespace validation
diff --git a/tests/validation/fixtures/RemapFixture.h b/tests/validation/fixtures/RemapFixture.h
index 2cb8e67f62..03cb6aef42 100644
--- a/tests/validation/fixtures/RemapFixture.h
+++ b/tests/validation/fixtures/RemapFixture.h
@@ -50,7 +50,7 @@ public:
{
std::mt19937 gen(library->seed());
std::uniform_int_distribution<uint8_t> distribution(0, 255);
- const T constant_border_value = static_cast<T>(distribution(gen));
+ PixelValue constant_border_value{ static_cast<T>(distribution(gen)) };
_data_layout = data_layout;
_target = compute_target(shape, policy, data_type, border_mode, constant_border_value);
@@ -59,13 +59,35 @@ public:
protected:
template <typename U>
- void fill(U &&tensor, int i, float min, float max)
+ void fill(U &&tensor, int i, int min, int max)
{
- std::uniform_int_distribution<> distribution((int)min, (int)max);
- library->fill(tensor, distribution, i);
+ switch(tensor.data_type())
+ {
+ case DataType::F32:
+ {
+ // map_x,y as integer values
+ std::uniform_int_distribution<int> distribution(min, max);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::F16:
+ {
+ arm_compute::utils::uniform_real_distribution_16bit<half> distribution(static_cast<float>(min), static_cast<float>(max));
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::U8:
+ {
+ std::uniform_int_distribution<uint8_t> distribution(min, max);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("DataType for Remap not supported");
+ }
}
- TensorType compute_target(TensorShape shape, InterpolationPolicy policy, DataType data_type, BorderMode border_mode, T constant_border_value)
+ TensorType compute_target(TensorShape shape, InterpolationPolicy policy, DataType data_type, BorderMode border_mode, PixelValue constant_border_value)
{
if(_data_layout == DataLayout::NHWC)
{
@@ -111,14 +133,16 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape shape, InterpolationPolicy policy, DataType data_type, BorderMode border_mode, T constant_border_value)
+ SimpleTensor<T> compute_reference(const TensorShape shape, InterpolationPolicy policy, DataType data_type, BorderMode border_mode, PixelValue constant_border_value)
{
- ARM_COMPUTE_ERROR_ON(data_type != DataType::U8);
+ ARM_COMPUTE_ERROR_ON(data_type != DataType::U8 && data_type != DataType::F16);
// Create reference
SimpleTensor<T> src{ shape, data_type };
SimpleTensor<float> map_x{ shape, DataType::F32 };
SimpleTensor<float> map_y{ shape, DataType::F32 };
+ T border_value{};
+ constant_border_value.get(border_value);
// Create the valid mask Tensor
_valid_mask = SimpleTensor<T> { shape, data_type };
@@ -131,7 +155,7 @@ protected:
fill(map_y, 2, -5, max_val);
// Compute reference
- return reference::remap<T>(src, map_x, map_y, _valid_mask, policy, border_mode, constant_border_value);
+ return reference::remap<T>(src, map_x, map_y, _valid_mask, policy, border_mode, border_value);
}
TensorType _target{};
diff --git a/tests/validation/reference/Remap.cpp b/tests/validation/reference/Remap.cpp
index 33c5a7de68..dfbe1aa12b 100644
--- a/tests/validation/reference/Remap.cpp
+++ b/tests/validation/reference/Remap.cpp
@@ -99,13 +99,16 @@ SimpleTensor<T> remap(const SimpleTensor<T> &in, SimpleTensor<float> &map_x, Sim
}
}
}
-
return out;
}
template SimpleTensor<uint8_t> remap(const SimpleTensor<uint8_t> &src, SimpleTensor<float> &map_x, SimpleTensor<float> &map_y, SimpleTensor<uint8_t> &valid_mask, InterpolationPolicy policy,
BorderMode border_mode,
uint8_t constant_border_value);
+
+template SimpleTensor<half> remap(const SimpleTensor<half> &src, SimpleTensor<float> &map_x, SimpleTensor<float> &map_y, SimpleTensor<half> &valid_mask, InterpolationPolicy policy,
+ BorderMode border_mode,
+ half constant_border_value);
} // namespace reference
} // namespace validation
} // namespace test