diff options
author | Kshitij Sisodia <kshitij.sisodia@arm.com> | 2022-01-04 13:37:53 +0000 |
---|---|---|
committer | Kshitij Sisodia <kshitij.sisodia@arm.com> | 2022-01-04 13:37:53 +0000 |
commit | b178b282e27e731dddb4f5a950fb5694edc4514a (patch) | |
tree | 01847ce2c3479146afdbc12c7d5bba9782cdc49d /tests | |
parent | b5b32d3e6188cc7126b3181e3be328d6583c5967 (diff) | |
download | ml-embedded-evaluation-kit-b178b282e27e731dddb4f5a950fb5694edc4514a.tar.gz |
MLECO-2834: Tests for Softmax function.
Added tests for recently added Softmax function in
PlatformMath module.
Change-Id: Iacf1f4eaf33a92e1d42275000765e7152d17176b
Diffstat (limited to 'tests')
-rw-r--r-- | tests/common/PlatformMathTests.cpp | 75 |
1 files changed, 74 insertions, 1 deletions
diff --git a/tests/common/PlatformMathTests.cpp b/tests/common/PlatformMathTests.cpp index 02653f2..2155886 100644 --- a/tests/common/PlatformMathTests.cpp +++ b/tests/common/PlatformMathTests.cpp @@ -597,6 +597,79 @@ TEST_CASE("Test ComplexMagnitudeSquaredF32") * of the input vector with output results */ arm::app::math::MathUtils::ComplexMagnitudeSquaredF32(input.data(), inputLen, output.data(), outputLen); - for (size_t i = 0; i < outputLen; i++) + for (size_t i = 0; i < outputLen; i++) { CHECK (expectedResult[i] == Approx(output[i])); + } +} + +/** + * @brief Simple function to test the Softmax function + * + * @param input Input vector + * @param goldenOutput Expected output vector + */ +static void TestSoftmaxF32( + const std::vector<float>& input, + const std::vector<float>& goldenOutput) +{ + std::vector<float> output = input; /* Function modifies the vector in-place */ + arm::app::math::MathUtils::SoftmaxF32(output); + + for (size_t i = 0; i < goldenOutput.size(); ++i) { + CHECK(goldenOutput[i] == Approx(output[i])); + } + + REQUIRE(output.size() == goldenOutput.size()); +} + +TEST_CASE("Test SoftmaxF32") +{ + SECTION("Simple series") { + const std::vector<float> input { + 0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0 + }; + + const std::vector<float> expectedOutput { + 7.80134161e-05, 2.12062451e-04, + 5.76445508e-04, 1.56694135e-03, + 4.25938820e-03, 1.15782175e-02, + 3.14728583e-02, 8.55520989e-02, + 2.32554716e-01, 6.32149258e-01 + }; + + TestSoftmaxF32(input, expectedOutput); + } + + SECTION("Random series") { + const std::vector<float> input { + 0.8810943246170809, 0.5877587675947015, + 0.6841546454788743, 0.4155920960071594, + 0.9799415323651671, 0.5066432973545711, + 0.3846024252355448, 0.4568689569632123, + 0.3284413744557605, 0.49152323726213554 + }; + + const std::vector<float> expectedOutput { + 0.13329595, 0.09940837, + 0.10946799, 0.08368583, + 0.14714509, 0.09166319, + 0.08113220, 0.08721240, + 0.07670132, 0.09028766 + }; + + TestSoftmaxF32(input, expectedOutput); + } + + SECTION("Series with large STD") { + const std::vector<float> input { + 0.001, 1000.000 + }; + + const std::vector<float> expectedOutput { + 0.000, 1.000 + }; + + TestSoftmaxF32(input, expectedOutput); + } } |