summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorKshitij Sisodia <kshitij.sisodia@arm.com>2022-01-04 13:37:53 +0000
committerKshitij Sisodia <kshitij.sisodia@arm.com>2022-01-04 13:37:53 +0000
commitb178b282e27e731dddb4f5a950fb5694edc4514a (patch)
tree01847ce2c3479146afdbc12c7d5bba9782cdc49d /tests
parentb5b32d3e6188cc7126b3181e3be328d6583c5967 (diff)
downloadml-embedded-evaluation-kit-b178b282e27e731dddb4f5a950fb5694edc4514a.tar.gz
MLECO-2834: Tests for Softmax function.
Added tests for recently added Softmax function in PlatformMath module. Change-Id: Iacf1f4eaf33a92e1d42275000765e7152d17176b
Diffstat (limited to 'tests')
-rw-r--r--tests/common/PlatformMathTests.cpp75
1 files changed, 74 insertions, 1 deletions
diff --git a/tests/common/PlatformMathTests.cpp b/tests/common/PlatformMathTests.cpp
index 02653f2..2155886 100644
--- a/tests/common/PlatformMathTests.cpp
+++ b/tests/common/PlatformMathTests.cpp
@@ -597,6 +597,79 @@ TEST_CASE("Test ComplexMagnitudeSquaredF32")
* of the input vector with output results */
arm::app::math::MathUtils::ComplexMagnitudeSquaredF32(input.data(), inputLen, output.data(), outputLen);
- for (size_t i = 0; i < outputLen; i++)
+ for (size_t i = 0; i < outputLen; i++) {
CHECK (expectedResult[i] == Approx(output[i]));
+ }
+}
+
+/**
+ * @brief Simple function to test the Softmax function
+ *
+ * @param input Input vector
+ * @param goldenOutput Expected output vector
+ */
+static void TestSoftmaxF32(
+ const std::vector<float>& input,
+ const std::vector<float>& goldenOutput)
+{
+ std::vector<float> output = input; /* Function modifies the vector in-place */
+ arm::app::math::MathUtils::SoftmaxF32(output);
+
+ for (size_t i = 0; i < goldenOutput.size(); ++i) {
+ CHECK(goldenOutput[i] == Approx(output[i]));
+ }
+
+ REQUIRE(output.size() == goldenOutput.size());
+}
+
+TEST_CASE("Test SoftmaxF32")
+{
+ SECTION("Simple series") {
+ const std::vector<float> input {
+ 0.0, 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0, 9.0
+ };
+
+ const std::vector<float> expectedOutput {
+ 7.80134161e-05, 2.12062451e-04,
+ 5.76445508e-04, 1.56694135e-03,
+ 4.25938820e-03, 1.15782175e-02,
+ 3.14728583e-02, 8.55520989e-02,
+ 2.32554716e-01, 6.32149258e-01
+ };
+
+ TestSoftmaxF32(input, expectedOutput);
+ }
+
+ SECTION("Random series") {
+ const std::vector<float> input {
+ 0.8810943246170809, 0.5877587675947015,
+ 0.6841546454788743, 0.4155920960071594,
+ 0.9799415323651671, 0.5066432973545711,
+ 0.3846024252355448, 0.4568689569632123,
+ 0.3284413744557605, 0.49152323726213554
+ };
+
+ const std::vector<float> expectedOutput {
+ 0.13329595, 0.09940837,
+ 0.10946799, 0.08368583,
+ 0.14714509, 0.09166319,
+ 0.08113220, 0.08721240,
+ 0.07670132, 0.09028766
+ };
+
+ TestSoftmaxF32(input, expectedOutput);
+ }
+
+ SECTION("Series with large STD") {
+ const std::vector<float> input {
+ 0.001, 1000.000
+ };
+
+ const std::vector<float> expectedOutput {
+ 0.000, 1.000
+ };
+
+ TestSoftmaxF32(input, expectedOutput);
+ }
}