diff options
Diffstat (limited to 'tests/validation/NEON')
-rw-r--r-- | tests/validation/NEON/QuantizationLayer.cpp | 72 |
1 files changed, 65 insertions, 7 deletions
diff --git a/tests/validation/NEON/QuantizationLayer.cpp b/tests/validation/NEON/QuantizationLayer.cpp index a4af2a2886..a5372b897c 100644 --- a/tests/validation/NEON/QuantizationLayer.cpp +++ b/tests/validation/NEON/QuantizationLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -43,11 +43,11 @@ namespace validation namespace { /** Tolerance for quantization */ -constexpr AbsoluteTolerance<uint8_t> tolerance_u8(1); -constexpr AbsoluteTolerance<uint16_t> tolerance_u16(1); - -const auto QuantizationSmallShapes = concat(datasets::Small3DShapes(), datasets::Small4DShapes()); -const auto QuantizationLargeShapes = concat(datasets::Large3DShapes(), datasets::Large4DShapes()); +constexpr AbsoluteTolerance<uint8_t> tolerance_u8(1); /**< Tolerance value for comparing reference's output against implementation's output for QASYMM8 data types */ +constexpr AbsoluteTolerance<int8_t> tolerance_s8(1); /**< Tolerance value for comparing reference's output against implementation's output for QASYMM8_SIGNED data types */ +constexpr AbsoluteTolerance<uint16_t> tolerance_u16(1); /**< Tolerance value for comparing reference's output against implementation's output for QASYMM16 data types */ +const auto QuantizationSmallShapes = concat(datasets::Small3DShapes(), datasets::Small4DShapes()); +const auto QuantizationLargeShapes = concat(datasets::Large3DShapes(), datasets::Large4DShapes()); } // namespace TEST_SUITE(NEON) @@ -56,7 +56,7 @@ TEST_SUITE(QuantizationLayer) // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( - framework::dataset::make("InputInfo", { TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::QASYMM8), // Wrong input data type + framework::dataset::make("InputInfo", { TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::QASYMM8), // Wrong output data type TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::F32), // Wrong output data type TensorInfo(TensorShape(16U, 16U, 2U, 5U), 1, DataType::F32), // Missmatching shapes TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::F32), // Valid @@ -193,6 +193,64 @@ TEST_SUITE_END() // FP16 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_SUITE_END() // Float +TEST_SUITE(Quantized) +template <typename T> +using NEQuantizationLayerQASYMM8GenFixture = QuantizationValidationGenericFixture<Tensor, Accessor, NEQuantizationLayer, T, uint8_t>; +template <typename T> +using NEQuantizationLayerQASYMM8_SIGNEDGenFixture = QuantizationValidationGenericFixture<Tensor, Accessor, NEQuantizationLayer, T, int8_t>; +template <typename T> +using NEQuantizationLayerQASYMM16GenFixture = QuantizationValidationGenericFixture<Tensor, Accessor, NEQuantizationLayer, T, uint16_t>; +TEST_SUITE(QASYMM8) +FIXTURE_DATA_TEST_CASE(RunSmallQASYMM8, NEQuantizationLayerQASYMM8GenFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(QuantizationSmallShapes, + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", { DataType::QASYMM8 })), + framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(0.5f, 10) })), + framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(2.0f, 15) }))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_u8); +} +FIXTURE_DATA_TEST_CASE(RunSmallQASYMM8_SIGNED, NEQuantizationLayerQASYMM8_SIGNEDGenFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(QuantizationSmallShapes, + framework::dataset::make("DataTypeIn", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", { DataType::QASYMM8_SIGNED })), + framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(1.0f, 10), QuantizationInfo(2.0f, -25) })), + framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(1.0f, 15) }))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_s8); +} +FIXTURE_DATA_TEST_CASE(RunSmallQASYMM16, NEQuantizationLayerQASYMM16GenFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(QuantizationSmallShapes, + framework::dataset::make("DataTypeIn", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", { DataType::QASYMM16 })), + framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(1.0f, 10) })), + framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(4.0f, 23) }))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_u16); +} +TEST_SUITE_END() // QASYMM8 +TEST_SUITE(QASYMM8_SIGNED) +FIXTURE_DATA_TEST_CASE(RunSmallQASYMM8_SIGNED, NEQuantizationLayerQASYMM8_SIGNEDGenFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(QuantizationSmallShapes, + framework::dataset::make("DataTypeIn", DataType::QASYMM8_SIGNED)), + framework::dataset::make("DataTypeOut", { DataType::QASYMM8_SIGNED })), + framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(1.0f, 10) })), + framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(2.0f, -5) }))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_s8); +} +FIXTURE_DATA_TEST_CASE(RunSmallQASYMM8, NEQuantizationLayerQASYMM8GenFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(QuantizationSmallShapes, + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), + framework::dataset::make("DataTypeOut", { DataType::QASYMM8 })), + framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(2.0f, 10), QuantizationInfo(2.0f, -25) })), + framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(1.0f, 30) }))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_u8); +} +TEST_SUITE_END() // QASYMM8_SIGNED +TEST_SUITE_END() // Quantized + TEST_SUITE_END() // QuantizationLayer TEST_SUITE_END() // NEON } // namespace validation |