aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/SoftmaxLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/NEON/SoftmaxLayer.cpp')
-rw-r--r--tests/validation/NEON/SoftmaxLayer.cpp10
1 files changed, 5 insertions, 5 deletions
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index c429782e60..8af3847cf8 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -97,9 +97,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
framework::dataset::make("axis", { 1,
1,
1,
+ -1,
1,
- 1,
- 0,
+ -3,
})),
framework::dataset::make("Expected", { false, false, false, true, true, false })),
input_info, output_info, beta, axis, expected)
@@ -188,7 +188,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall4D, NESoftmaxLayerQuantizedFixture<uint8_t>, fram
framework::dataset::make("DataType", DataType::QASYMM8)),
combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
framework::dataset::make("Beta", { 1.0f, 2.f }))),
- framework::dataset::make("Axis", { 1, 2, 3 })))
+ framework::dataset::make("Axis", { -1, 2, 3 })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -209,7 +209,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall2D, NESoftmaxLayerQuantizedFixture<int8_t>, frame
framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
framework::dataset::make("Beta", { 1.0f, 2.f }))),
- framework::dataset::make("Axis", { 1 })))
+ framework::dataset::make("Axis", { -1, 1 })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
@@ -218,7 +218,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall4D, NESoftmaxLayerQuantizedFixture<int8_t>, frame
framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
framework::dataset::make("Beta", { 1.0f, 2.f }))),
- framework::dataset::make("Axis", { 1, 2, 3 })))
+ framework::dataset::make("Axis", { -2, 2, 3 })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8_signed);