aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/DequantizationLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/NEON/DequantizationLayer.cpp')
-rw-r--r--tests/validation/NEON/DequantizationLayer.cpp39
1 files changed, 31 insertions, 8 deletions
diff --git a/tests/validation/NEON/DequantizationLayer.cpp b/tests/validation/NEON/DequantizationLayer.cpp
index 22d56ab5d8..9bdba7204f 100644
--- a/tests/validation/NEON/DequantizationLayer.cpp
+++ b/tests/validation/NEON/DequantizationLayer.cpp
@@ -44,35 +44,56 @@ namespace
{
/** Tolerance for float operations */
constexpr AbsoluteTolerance<float> tolerance_f32(0.001f);
+
+const auto DequantizationShapes = concat(concat(concat(datasets::Small3DShapes(),
+ datasets::Large3DShapes()),
+ datasets::Small4DShapes()),
+ datasets::Large4DShapes());
+
} // namespace
TEST_SUITE(NEON)
TEST_SUITE(DequantizationLayer)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(concat(datasets::Small2DShapes(), datasets::Large2DShapes()), framework::dataset::make("DataType", DataType::U8)), shape, data_type)
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(DequantizationShapes, framework::dataset::make("DataType", DataType::U8)), shape, data_type)
{
+ TensorShape shape_min_max = shape;
+ shape_min_max.set(Window::DimX, 2);
+
+ // Remove Y and Z dimensions and keep the batches
+ shape_min_max.remove_dimension(1);
+ shape_min_max.remove_dimension(1);
+
// Create tensors
- Tensor src = create_tensor<Tensor>(shape, data_type);
- Tensor dst = create_tensor<Tensor>(shape, DataType::F32);
+ Tensor src = create_tensor<Tensor>(shape, data_type);
+ Tensor dst = create_tensor<Tensor>(shape, DataType::F32);
+ Tensor min_max = create_tensor<Tensor>(shape_min_max, DataType::F32);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(min_max.info()->is_resizable(), framework::LogLevel::ERRORS);
// Create and configure function
- float min = 0.f;
- float max = 0.f;
NEDequantizationLayer dequant_layer;
- dequant_layer.configure(&src, &dst, &min, &max);
+ dequant_layer.configure(&src, &dst, &min_max);
// Validate valid region
const ValidRegion valid_region = shape_to_valid_region(shape);
validate(src.info()->valid_region(), valid_region);
validate(dst.info()->valid_region(), valid_region);
+ // Validate valid region of min_max tensor
+ const ValidRegion valid_region_min_max = shape_to_valid_region(shape_min_max);
+ validate(min_max.info()->valid_region(), valid_region_min_max);
+
// Validate padding
const PaddingSize padding = PaddingCalculator(shape.x(), 8).required_padding();
validate(src.info()->padding(), padding);
validate(dst.info()->padding(), padding);
+
+ // Validate padding of min_max tensor
+ const PaddingSize padding_min_max = PaddingCalculator(shape_min_max.x(), 2).required_padding();
+ validate(min_max.info()->padding(), padding_min_max);
}
template <typename T>
@@ -80,12 +101,14 @@ using NEDequantizationLayerFixture = DequantizationValidationFixture<Tensor, Acc
TEST_SUITE(Integer)
TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small2DShapes(), framework::dataset::make("DataType", DataType::U8)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small3DShapes(), datasets::Small4DShapes()),
+ framework::dataset::make("DataType", DataType::U8)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(datasets::Large2DShapes(), framework::dataset::make("DataType", DataType::U8)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(concat(datasets::Large3DShapes(), datasets::Large4DShapes()),
+ framework::dataset::make("DataType", DataType::U8)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f32);