From b178b282e27e731dddb4f5a950fb5694edc4514a Mon Sep 17 00:00:00 2001 From: Kshitij Sisodia Date: Tue, 4 Jan 2022 13:37:53 +0000 Subject: MLECO-2834: Tests for Softmax function. Added tests for recently added Softmax function in PlatformMath module. Change-Id: Iacf1f4eaf33a92e1d42275000765e7152d17176b --- tests/common/PlatformMathTests.cpp | 75 +++++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/common/PlatformMathTests.cpp b/tests/common/PlatformMathTests.cpp index 02653f2..2155886 100644 --- a/tests/common/PlatformMathTests.cpp +++ b/tests/common/PlatformMathTests.cpp @@ -597,6 +597,79 @@ TEST_CASE("Test ComplexMagnitudeSquaredF32") * of the input vector with output results */ arm::app::math::MathUtils::ComplexMagnitudeSquaredF32(input.data(), inputLen, output.data(), outputLen); - for (size_t i = 0; i < outputLen; i++) + for (size_t i = 0; i < outputLen; i++) { CHECK (expectedResult[i] == Approx(output[i])); + } +} + +/** + * @brief Simple function to test the Softmax function + * + * @param input Input vector + * @param goldenOutput Expected output vector + */ +static void TestSoftmaxF32( + const std::vector& input, + const std::vector& goldenOutput) +{ + std::vector output = input; /* Function modifies the vector in-place */ + arm::app::math::MathUtils::SoftmaxF32(output); + + for (size_t i = 0; i < goldenOutput.size(); ++i) { + CHECK(goldenOutput[i] == Approx(output[i])); + } + + REQUIRE(output.size() == goldenOutput.size()); +} + +TEST_CASE("Test SoftmaxF32") +{ + SECTION("Simple series") { + const std::vector input { + 0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0 + }; + + const std::vector expectedOutput { + 7.80134161e-05, 2.12062451e-04, + 5.76445508e-04, 1.56694135e-03, + 4.25938820e-03, 1.15782175e-02, + 3.14728583e-02, 8.55520989e-02, + 2.32554716e-01, 6.32149258e-01 + }; + + TestSoftmaxF32(input, expectedOutput); + } + + SECTION("Random series") { + const std::vector input { + 0.8810943246170809, 0.5877587675947015, + 0.6841546454788743, 0.4155920960071594, + 0.9799415323651671, 0.5066432973545711, + 0.3846024252355448, 0.4568689569632123, + 0.3284413744557605, 0.49152323726213554 + }; + + const std::vector expectedOutput { + 0.13329595, 0.09940837, + 0.10946799, 0.08368583, + 0.14714509, 0.09166319, + 0.08113220, 0.08721240, + 0.07670132, 0.09028766 + }; + + TestSoftmaxF32(input, expectedOutput); + } + + SECTION("Series with large STD") { + const std::vector input { + 0.001, 1000.000 + }; + + const std::vector expectedOutput { + 0.000, 1.000 + }; + + TestSoftmaxF32(input, expectedOutput); + } } -- cgit v1.2.1