/*
 * SPDX-FileCopyrightText: Copyright 2021 - 2023 Arm Limited and/or its affiliates
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "PlatformMath.hpp"

#include <catch.hpp>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

/**
 * @brief   Sample points 0.00, 0.01, ..., 1.00 (101 values). The
 *          trigonometric tests evaluate their function at pi times each
 *          of these points.
 * @return  Vector of the 101 sample points, in ascending order.
 */
static std::vector<float> TrigTestFractions()
{
    return {
        0.0,  0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09,
        0.1,  0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19,
        0.2,  0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29,
        0.3,  0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39,
        0.4,  0.41, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49,
        0.5,  0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59,
        0.6,  0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69,
        0.7,  0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79,
        0.8,  0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89,
        0.9,  0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99,
        1.0
    };
}

TEST_CASE("Test CosineF32")
{
    /* Test constants: CosineF32 is evaluated at pi * inputA[i]. */
    const std::vector<float> inputA = TrigTestFractions();

    /* Reference values of cos(pi * x) for each x in inputA. */
    const std::vector<float> expectedResult{
        1.0, 0.9995065603657316, 0.9980267284282716, 0.99556196460308, 0.9921147013144779,
        0.9876883405951378, 0.9822872507286887, 0.9759167619387474, 0.9685831611286311, 0.9602936856769431,
        0.9510565162951535, 0.9408807689542255, 0.9297764858882515, 0.9177546256839811, 0.9048270524660195,
        0.8910065241883679, 0.8763066800438636, 0.8607420270039436, 0.8443279255020151, 0.8270805742745618,
        0.8090169943749475, 0.7901550123756904, 0.7705132427757893, 0.7501110696304596, 0.7289686274214116,
        0.7071067811865476, 0.6845471059286886, 0.6613118653236518, 0.6374239897486896, 0.6129070536529766,
        0.5877852522924731, 0.5620833778521306, 0.5358267949789965, 0.5090414157503712, 0.48175367410171516,
        0.4539904997395468, 0.42577929156507266, 0.39714789063478056, 0.3681245526846781, 0.3387379202452915,
        0.30901699437494745, 0.2789911060392295, 0.24868988716485496, 0.2181432413965427, 0.18738131458572474,
        0.15643446504023092, 0.12533323356430426, 0.0941083133185145, 0.06279051952931353, 0.031410759078128396,
        6.123233995736766e-17, -0.03141075907812828, -0.0627905195293134, -0.09410831331851438, -0.12533323356430437,
        -0.15643446504023104, -0.18738131458572482, -0.21814324139654234, -0.24868988716485463, -0.27899110603922916,
        -0.30901699437494734, -0.33873792024529137, -0.368124552684678, -0.39714789063478045, -0.4257792915650727,
        -0.4539904997395467, -0.48175367410171543, -0.5090414157503713, -0.5358267949789969, -0.5620833778521304,
        -0.587785252292473, -0.6129070536529763, -0.6374239897486897, -0.6613118653236517, -0.6845471059286887,
        -0.7071067811865475, -0.7289686274214113, -0.7501110696304596, -0.7705132427757891, -0.7901550123756904,
        -0.8090169943749473, -0.8270805742745619, -0.8443279255020149, -0.8607420270039435, -0.8763066800438634,
        -0.8910065241883678, -0.9048270524660194, -0.9177546256839811, -0.9297764858882513, -0.9408807689542255,
        -0.9510565162951535, -0.9602936856769431, -0.9685831611286311, -0.9759167619387474, -0.9822872507286886,
        -0.9876883405951377, -0.9921147013144778, -0.99556196460308, -0.9980267284282716, -0.9995065603657316,
        -1.0
    };

    /* Absolute margin of 1e-6 (same value as the former spelling "10e-7"). */
    const float tolerance = 1e-6f;

    /* The loop indexes both vectors with the same index. */
    REQUIRE(expectedResult.size() == inputA.size());
    for (size_t i = 0; i < inputA.size(); i++) {
        CHECK(expectedResult[i] ==
              Approx(arm::app::math::MathUtils::CosineF32(M_PI * inputA[i])).margin(tolerance));
    }
}

TEST_CASE("Test SineF32")
{
    /* Test constants: SineF32 is evaluated at pi * inputA[i]. */
    const std::vector<float> inputA = TrigTestFractions();

    /* Reference values of sin(pi * x) for each x in inputA. */
    const std::vector<float> expectedResult{
        0.0, 0.03141075907812829, 0.06279051952931337, 0.09410831331851431, 0.12533323356430426,
        0.15643446504023087, 0.1873813145857246, 0.21814324139654256, 0.2486898871648548, 0.2789911060392293,
        0.3090169943749474, 0.33873792024529137, 0.3681245526846779, 0.3971478906347806, 0.4257792915650727,
        0.45399049973954675, 0.4817536741017153, 0.5090414157503713, 0.5358267949789967, 0.5620833778521306,
        0.5877852522924731, 0.6129070536529764, 0.6374239897486896, 0.6613118653236518, 0.6845471059286886,
        0.7071067811865475, 0.7289686274214116, 0.7501110696304596, 0.7705132427757893, 0.7901550123756903,
        0.8090169943749475, 0.8270805742745618, 0.8443279255020151, 0.8607420270039436, 0.8763066800438637,
        0.8910065241883678, 0.9048270524660196, 0.9177546256839811, 0.9297764858882513, 0.9408807689542255,
        0.9510565162951535, 0.960293685676943, 0.9685831611286311, 0.9759167619387473, 0.9822872507286886,
        0.9876883405951378, 0.9921147013144779, 0.99556196460308, 0.9980267284282716, 0.9995065603657316,
        1.0, 0.9995065603657316, 0.9980267284282716, 0.99556196460308, 0.9921147013144778,
        0.9876883405951377, 0.9822872507286886, 0.9759167619387474, 0.9685831611286312, 0.9602936856769431,
        0.9510565162951536, 0.9408807689542255, 0.9297764858882513, 0.9177546256839813, 0.9048270524660195,
        0.8910065241883679, 0.8763066800438635, 0.8607420270039436, 0.844327925502015, 0.827080574274562,
        0.8090169943749475, 0.7901550123756905, 0.7705132427757893, 0.7501110696304597, 0.7289686274214114,
        0.7071067811865476, 0.6845471059286888, 0.6613118653236518, 0.6374239897486899, 0.6129070536529764,
        0.5877852522924732, 0.5620833778521305, 0.535826794978997, 0.5090414157503714, 0.4817536741017156,
        0.45399049973954686, 0.4257792915650729, 0.3971478906347806, 0.36812455268467814, 0.3387379202452913,
        0.3090169943749475, 0.2789911060392291, 0.24868988716485482, 0.21814324139654231, 0.18738131458572502,
        0.15643446504023098, 0.12533323356430454, 0.09410831331851435, 0.06279051952931358, 0.031410759078128236,
        1.2246467991473532e-16
    };

    /* Absolute margin of 1e-3 (same value as the former spelling "10e-4"). */
    const float tolerance = 1e-3f;

    /* The loop indexes both vectors with the same index. */
    REQUIRE(expectedResult.size() == inputA.size());
    for (size_t i = 0; i < inputA.size(); i++) {
        CHECK(expectedResult[i] ==
              Approx(arm::app::math::MathUtils::SineF32(M_PI * inputA[i])).margin(tolerance));
    }
}

TEST_CASE("Test SqrtF32")
{
    /* Test constants: */
    const std::vector<float> inputA{0, 1, 2, 9, M_PI};
    const std::vector<float> expectedResult{0, 1, 1.414213562, 3, 1.772453851};

    REQUIRE(expectedResult.size() == inputA.size());
    for (size_t i = 0; i < inputA.size(); i++) {
        CHECK(expectedResult[i] == Approx(arm::app::math::MathUtils::SqrtF32(inputA[i])));
    }
}

TEST_CASE("Test MeanF32")
{
    /* Test constants: */
    std::vector<float> input{0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 1.000};

    /* Manually calculated mean of above vector */
    const float expectedResult = 0.100;
    CHECK(expectedResult == Approx(arm::app::math::MathUtils::MeanF32(input.data(), input.size())));

    /* Mean of 0 */
    std::vector<float> input2{1, 2, -1, -2};
    const float expectedResult2 = 0.0f;
    CHECK(expectedResult2 == Approx(arm::app::math::MathUtils::MeanF32(input2.data(), input2.size())));

    /* All 0s */
    std::vector<float> input3(9, 0);
    const float expectedResult3 = 0.0f;
    CHECK(expectedResult3 == Approx(arm::app::math::MathUtils::MeanF32(input3.data(), input3.size())));

    /* All 1s */
    std::vector<float> input4(9, 1);
    const float expectedResult4 = 1.0f;
    CHECK(expectedResult4 == Approx(arm::app::math::MathUtils::MeanF32(input4.data(), input4.size())));
}

TEST_CASE("Test StdDevF32")
{
    /* Test constants: */
    /* Normally distributed sample data generated by numpy normal library */
    std::vector<float> input{
        1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069,
        0.3190391, -0.24937038, 1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127,
        -0.17242821, -0.87785842, 0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434,
        0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547, -0.69166075, -0.39675353,
        -0.6871727, -0.84520564, -0.67124613, -0.0126646, -1.11731035, 0.2344157, 1.65980218, 0.74204416,
        -0.19183555, -0.88762896, -0.74715829, 1.6924546, 0.05080775, -0.63699565, 0.19091548, 2.10025514,
        0.12015895, 0.61720311
    };

    const auto inputLen = static_cast<uint32_t>(input.size());

    /* Calculate mean using std library to avoid dependency on MathUtils::MeanF32 */
    const float mean = std::accumulate(input.begin(), input.end(), 0.0f) / static_cast<float>(inputLen);

    const float output = arm::app::math::MathUtils::StdDevF32(input.data(), inputLen, mean);

    /* Manually calculated standard deviation of above vector */
    const float expectedResult = 0.969589282958136;
    CHECK(expectedResult == Approx(output));

    /* All 0s should have 0 std dev. */
    std::vector<float> input2(4, 0);
    const float expectedResult2 = 0.0f;
    CHECK(expectedResult2 ==
          Approx(arm::app::math::MathUtils::StdDevF32(input2.data(), input2.size(), 0.0f)));

    /* All 1s should have 0 std dev. */
    std::vector<float> input3(4, 1);
    const float expectedResult3 = 0.0f;
    CHECK(expectedResult3 ==
          Approx(arm::app::math::MathUtils::StdDevF32(input3.data(), input3.size(), 1.0f)));

    /* Manually calculated std value */
    std::vector<float> input4{1, 2, 3, 4, 5, 6, 7, 8, 9, 0};
    const float mean2 =
        std::accumulate(input4.begin(), input4.end(), 0.0f) / static_cast<float>(input4.size());
    const float expectedResult4 = 2.872281323;
    CHECK(expectedResult4 ==
          Approx(arm::app::math::MathUtils::StdDevF32(input4.data(), input4.size(), mean2)));
}

TEST_CASE("Test FFT32")
{
    constexpr size_t nElem = 512;

    /* Test constants: */
    std::vector<float> input_zeros(nElem, 0);
    std::vector<float> input_ones(nElem, 1);

    /* Random numbers generated using numpy rand with range [0:1] */
    std::vector<float> input_random{
        0.42333686, 0.6547418, 0.8933691, 0.91466254, 0.5992143, 0.99474055, 0.97750413, 0.97160685,
        0.72718734, 0.8699537, 0.643911, 0.09764466, 0.0050113136, 0.46823388, 0.13709934, 0.44892532,
        0.59728205, 0.04055081, 0.5579888, 0.18445836, 0.66469765, 0.82715863, 0.91934484, 0.7844356,
        0.23489648, 0.021708783, 0.67819905, 0.75761676, 0.48374954, 0.14006922, 0.87082034, 0.7694296,
        0.80479276, 0.8241704, 0.95917296, 0.5758142, 0.16839339, 0.34290153, 0.5846108, 0.6878044,
        0.1067114, 0.5198196, 0.4356897, 0.68049103, 0.12480807, 0.3538696, 0.06067087, 0.056964435,
        0.5382167, 0.07761527, 0.6673144, 0.9045368, 0.11050189, 0.03530183, 0.07864744, 0.98752064,
        0.014321936, 0.101833574, 0.43293256, 0.87102246, 0.52411795, 0.90232223, 0.49560344, 0.6803092,
        0.2908511, 0.14653015, 0.99105513, 0.7057098, 0.09623502, 0.039713606, 0.88669086, 0.56018597,
        0.90632766, 0.99241334, 0.18748309, 0.38991618, 0.6359827, 0.05665585, 0.732304, 0.2703365,
        0.19014524, 0.5017947, 0.78862536, 0.81253093, 0.35050204, 0.2832596, 0.65221876, 0.59856164,
        0.42758793, 0.78865635, 0.30943435, 0.93780816, 0.62568265, 0.35397422, 0.84209913, 0.48590583,
        0.34837773, 0.5811646, 0.42924216, 0.26692122, 0.030709852, 0.84459823, 0.09085059, 0.29297647,
        0.48539516, 0.33488297,
        0.7877257, 0.8728821, 0.28454545, 0.7109578, 0.86097074, 0.8536262, 0.4978063, 0.5760398,
        0.77506036, 0.7716988, 0.27041402, 0.52340513, 0.2055419, 0.8728235, 0.13492358, 0.79122984,
        0.52998376, 0.33897072, 0.6426309, 0.8766521, 0.89287037, 0.74047667, 0.42341164, 0.67437655,
        0.4682156, 0.67123246, 0.54287183, 0.3580476, 0.94756556, 0.2699457, 0.6131569, 0.75043845,
        0.8115012, 0.49610943, 0.7108478, 0.90941435, 0.02233071, 0.37346774, 0.33732748, 0.46691266,
        0.35784695, 0.39391598, 0.8556212, 0.884142, 0.11730601, 0.550112, 0.31513855, 0.69654715,
        0.58585805, 0.4493127, 0.78515726, 0.8176612, 0.9846698, 0.32842383, 0.41843212, 0.48470423,
        0.6757128, 0.95876855, 0.5989163, 0.13587572, 0.72886884, 0.88291156, 0.34402263, 0.66211045,
        0.86188424, 0.21498202, 0.26397392, 0.67372984, 0.91386956, 0.7339788, 0.91308993, 0.1953016,
        0.1539217, 0.214701, 0.58234113, 0.8019992, 0.63969976, 0.041050985, 0.7293308, 0.26341477,
        0.54768014, 0.97596467, 0.12385198, 0.44149798, 0.5519762, 0.1697347, 0.577215, 0.8213594,
        0.47874716, 0.64515114, 0.61467725, 0.18463866, 0.23890929, 0.51052976, 0.16807361, 0.53142565,
        0.2414274, 0.41690814, 0.98815554, 0.6245643, 0.9477003, 0.24780034, 0.82469565, 0.8614785,
        0.9565832, 0.062440686, 0.9710724, 0.039196696, 0.11030199, 0.35234734, 0.02065066, 0.12832293,
        0.7328055, 0.48924434, 0.17247158, 0.5769348, 0.44146806, 0.53575355, 0.17258933, 0.6980237,
        0.86494404, 0.50573164, 0.5033998, 0.71199447, 0.41353586, 0.26767612, 0.3789118, 0.046621118,
        0.58491063, 0.22861995, 0.03134273, 0.53280216, 0.23382367, 0.07748905, 0.96875405, 0.6613716,
        0.64087844, 0.8377165, 0.051519375, 0.68997836, 0.3776376, 0.43362603, 0.5358754, 0.51419014,
        0.12823892, 0.26574057, 0.508808, 0.15734084, 0.78327274, 0.5045347, 0.5445746, 0.89297736,
        0.8531272, 0.91270804, 0.87429863, 0.3965137, 0.13544834, 0.74269205, 0.80592203, 0.045050766,
        0.13362087, 0.17090783, 0.02873757, 0.99339336, 0.6394376, 0.48203012, 0.70598215, 0.37082237,
        0.39792424, 0.89938444, 0.312602,
        0.48755112, 0.18220617, 0.17303479, 0.31954846, 0.78080165, 0.1755106, 0.68262285, 0.84665287,
        0.8520143, 0.8459509, 0.39417005, 0.30087698, 0.81362164, 0.61927587, 0.32739028, 0.9023775,
        0.27578092, 0.6830477, 0.15842387, 0.8473049, 0.43057114, 0.2019703, 0.20560141, 0.6237757,
        0.60283095, 0.27645138, 0.26605442, 0.27985683, 0.41353813, 0.85139906, 0.71711886, 0.5444832,
        0.73613757, 0.7397004, 0.7406752, 0.41016674, 0.31896713, 0.4541723, 0.2795807, 0.47941738,
        0.00504193, 0.89091027, 0.8097144, 0.63033766, 0.37252298, 0.9132861, 0.5102532, 0.04104481,
        0.30368647, 0.21573475, 0.99520445, 0.5047808, 0.6868845, 0.99881023, 0.30377692, 0.2554386,
        0.47201005, 0.11120686, 0.10077732, 0.1853349, 0.49159425, 0.3938629, 0.8989509, 0.9887155,
        0.698771, 0.695701, 0.78368753, 0.52537227, 0.19451462, 0.3659248, 0.1968508, 0.7751828,
        0.33103722, 0.40406147, 0.37832898, 0.68663514, 0.32225925, 0.41771907, 0.034218453, 0.42808908,
        0.20685343, 0.1861495, 0.045986768, 0.8532299, 0.17200677, 0.44670314, 0.56831235, 0.5388232,
        0.5430553, 0.69175136, 0.6462231, 0.42827028, 0.10050113, 0.30627027, 0.9967943, 0.6684778,
        0.5928422, 0.63392985, 0.99123496, 0.79301435, 0.7936309, 0.42839453, 0.39781123, 0.22329247,
        0.0122212395, 0.2807108, 0.19812097, 0.5576105, 0.115653396, 0.3732018, 0.7622857, 0.19847734,
        0.5310287, 0.7298145, 0.5518292, 0.9117333, 0.13215758, 0.33716795, 0.42372775, 0.6779287,
        0.35799992, 0.097887225, 0.20171605, 0.9948177, 0.1829232, 0.80349857, 0.9807098, 0.22959666,
        0.67322475, 0.63094735, 0.93454355, 0.15962408, 0.04335433, 0.47104993, 0.36784375, 0.45258796,
        0.93415564, 0.1655446, 0.7195017, 0.76236975, 0.3846913, 0.01330617, 0.84716374, 0.1227003,
        0.65102947, 0.6632434, 0.3728453, 0.4222391, 0.6942989, 0.16014872, 0.10798196, 0.94033676,
        0.026525471, 0.8379024, 0.5484514, 0.13500613, 0.22919805, 0.7001831, 0.6573261, 0.38086265,
        0.8725666, 0.35077834, 0.28415123, 0.42283052, 0.668379, 0.9769895, 0.37621376, 0.646407,
        0.11188069, 0.17129017, 0.7441628,
        0.25617477, 0.7751679, 0.8565412, 0.67631435, 0.45213568, 0.61896557, 0.3387995, 0.51607716,
        0.60779697, 0.16428445, 0.5080923, 0.13012086, 0.61184275, 0.7690249, 0.9578811, 0.67365676,
        0.16241212, 0.97157824, 0.5595742, 0.75936574, 0.6043881, 0.2149638, 0.4925318, 0.58727825,
        0.97953695, 0.01605968, 0.2819307, 0.6448378, 0.4265335, 0.661541, 0.3976571, 0.40607136,
        0.46425515, 0.2055872, 0.2716193, 0.4132582, 0.8372537, 0.37787434, 0.082228854, 0.7985557,
        0.9718134, 0.35222608, 0.4853643, 0.2569464, 0.14783978, 0.4889042, 0.62900156, 0.19994198,
        0.4618481, 0.21673755, 0.51749533, 0.1260157, 0.83759904, 0.36438805, 0.6704668, 0.22010763,
        0.2359318, 0.53004104, 0.9723652, 0.91218954, 0.9153926, 0.48207277, 0.34850466, 0.8939421
    };

    std::vector<std::vector<float>> input_vectors{input_zeros, input_ones, input_random};

    std::vector<float> output_zeros(nElem, 0);
    std::vector<float> output_ones(nElem, 0);
    std::vector<float> output_random(nElem, 0);
    std::vector<std::vector<float>> output_vectors{output_zeros, output_ones, output_random};

    std::vector<float> expected_result_zeros(nElem, 0);
    std::vector<float> expected_result_ones(nElem, 0);
    expected_result_ones[0] = static_cast<float>(nElem);

    /* Values are stored as [real0, realN/2, real1, im1, real2, im2, ...] */
    std::vector<float> expected_result_random{
        259.510161, 2.59796867, 2.55982143, -5.91349888, -1.80049237, 1.09902763, 4.0094324, 7.76684892,
        4.32617219, -4.33636417, 2.98128463, -0.83763449, 2.92973078, -3.75655459, 2.27203161, -4.61106145,
        5.55562176, -4.71880166, 2.13693416, -2.20496619, 2.31174036, -3.52991041, -0.61687068, 2.43455407,
        4.76317833, 8.66518565, 1.72350562, -1.33641312, 5.82836675, 0.89396187, 8.15031483, 4.34599034,
        5.99780199, 2.94900065, -0.0234462045, -5.03789597, 12.190702, -6.47012928, 1.24434715, 0.23621713,
        0.920921279, -9.20510398, -0.773267254, -3.72141078, 6.28883709, 4.00634065, 4.46491682, 0.74625307,
        -0.587506158, -4.22833058, 3.15189786, -1.82518672, -6.9378226, 2.4170692, 3.23045185, 3.33383799,
        0.0510059531, -3.4233929, -2.91651323, -0.0258584, 5.84499843, -9.51454903, -14.9214047, 5.52200123,
        4.5217959, -7.08268703, 0.51677542, -2.90878759, 5.04314682, -1.16928599, -10.7329243, -0.2719951,
        -3.95269565, -2.32475678, -4.11031641, -2.20538835, -0.589005095, -5.65483456, -10.8927018, 5.74801823,
        -5.72520347, -3.94970165, -0.518407515, 1.23622633, 9.56297959, 0.24424306, -3.74306351, 9.63476301,
        -4.74493837, -0.35443496, 6.54760504, 1.16188913, -13.341695, 7.19088609, -2.560458, -0.49557866,
        2.93460322, 6.91076746, -0.284779221, 6.59958391, 1.70963995, -7.67293252, 4.1850079, -0.14627552,
        -1.24855113, -1.43322867, 0.360644904, -4.11521374, -10.0628421, 2.87531563, -9.34809732, -0.58251846,
        3.61799848, -5.10288284, -5.96239076, 1.99792128, -0.0783229243, 1.81741166, 2.32709681, 0.68487206,
        -3.08398468, 1.5177629, 1.41015388, 4.51146401, 1.90769911, 0.56093423, 3.58389141, -0.14974575,
        -2.20163907, -1.62177814, -1.91904127, 1.94645907, -0.13772293, -4.30291678, -2.61843435, 0.12691241,
        3.28959117, 4.7309582, -2.93995652, 0.39835926, 2.89711768, 1.42284586, -5.13129145, -7.26477374,
        3.74616158, -2.59659457, 3.8574875, -2.93737277, 3.17748694, -4.45041455, -2.68466437, 1.37377726,
        1.60008368, 1.63787578, -1.95661401, -6.34937202, -2.62744282,
        -5.20892662, 0.890553959, -6.37113573, 2.35885332, -10.04547561, 0.329866159, 1.89217741, 0.882516491,
        -3.53298728, -2.22525608, -5.64794388, -5.19226843, 2.5971315, 4.49346648, -0.20428409, -6.14851885,
        -11.90893875, 3.75899776, -1.86910056, 2.78518535, -2.6359501, -2.13423317, 4.86509946, -2.37625499,
        7.42404308, 6.71175474, 6.06191618, 2.59014379, 4.76329698, 9.19140042, 9.69149015, 3.33307819,
        -4.03094924, 2.12988453, -0.15820258, -10.3422801, -3.04462388, 3.59852152, -2.00887343, -3.69998656,
        0.90050102, 0.679959099, -1.88604949, 1.24235316, 0.41309537, 6.13876866, -7.1040085, 6.17728674,
        1.91667103, -1.32895472, -0.17674661, -6.94720428, 3.10502593, -2.33990738, -1.27840434, 3.2144252,
        2.14102714, 2.37498837, 3.8158066, -2.24107675, -5.52527559, -2.9569793, -0.50367608, 3.01687661,
        7.08195792, 6.7860479, -3.94154162, 2.24402195, 4.60132638, 3.42211139, 4.17689039, -1.17277194,
        2.15404472, -2.3748193, 1.42611867, -0.463033506, 3.21563035, 1.38662123, 3.98598717, -3.75283402,
        -2.47600433, -1.97290542, 2.83361487, -0.845662834, 5.57411581, -0.972981483, -11.394208, 1.88220611,
        -0.80225125, -0.434295854, -8.2954126, 3.81795409, -3.17146, -4.61994107, -1.59820505, -5.98834455,
        -4.93129451, -0.513862996, -0.15649305, -5.59094391, 6.25244435, -6.59974456, 13.17193115, 4.48609092,
        1.64741879, 7.40985006, 0.44896188, 3.81058449, -0.76425931, -5.47938416, 4.01447941, -3.21535548,
        -1.45542238, 0.72274083, -0.23983128, -4.32373034, 0.1337671, -5.89365226, -3.18756318, 7.90979161,
        5.27570134, -3.43094553, -6.00826981, 1.17932561, -3.50027177, 0.181306385, 1.1062498, 0.723650536,
        -1.55500613, -3.88047911, -2.43746762, -6.81565579, 2.16343352, 2.46366137, 2.38704469, -2.55106395,
        6.5091449, -2.06510578, 11.11320924, 2.06649835, -1.05026064, 1.63564303, -0.04638729, 1.45053876,
        0.43730146, 1.25027939, 0.79932743, 2.81088838, 6.95136058, -4.41417255, 2.89610628, 1.15426258,
        -2.60704937, -2.77744882, 4.12872365, -2.98288336, 6.75607352, 2.36553382, -2.10540332, -7.30042988,
        -5.44897893, 3.44048454, 4.29726231, 2.181995, -0.80126759, -4.04051175, 4.57584864, 0.956312116,
        -4.45183318, 3.42348929, -9.84138181, 8.69604433, -6.6481311, 0.468232735, 1.41031176, -1.240857,
        10.61672181, 0.356591473, 10.51631421, -0.99743547, 2.72157537, 8.63583929, -2.19404252, -1.53605811,
        -4.41068581, 2.05371873, 1.25665769, 1.65289503, 4.52520582, -0.535062642, 0.82084677, -11.0079476,
        -5.09361474, 9.63129107, 3.90056638, -4.19779738, 0.06565745, 2.42526917, 2.5854233, -3.66709357,
        3.80502971, -0.101489353, -6.85423228, -13.9361494, -0.43904617, 6.01800968, -1.30751495, -4.75122234,
        2.74740671, 5.54971138, 9.43409003, -0.994733058, -3.0096825, -8.60263376, 0.36653762, 3.53318614,
        -2.69194556, -8.9514574, -4.71570923, 5.15417709, 2.68645385, -2.78042293, 8.21739385, 0.590225003,
        2.13319153, 1.72158888, 0.18114627, 3.92269446, 3.3525857, -3.40313825, 4.39280934, -1.70368966,
        1.29121245, -3.11326453, -1.85941318, 4.57078881, 0.72531039, 6.00445664, 4.9588524, -3.32944491,
        -0.02080722, -7.42374632, -3.23290026, -0.81614579, 3.55935439, -0.619206533, 2.42859073, 2.21486456,
        3.76402487, 3.90930695, -3.61610186, -0.812712547, -14.63377988, 1.14460823, -3.14089899, 3.18097435,
        1.21957751, 2.85181833, 0.89990235, -4.32147361, -5.54219361, -1.12253677, 2.96141081, -4.4257707,
        3.17282306, 4.9174671, 1.16977744, -4.55148089, -2.82520179, -1.71684103, 1.91487668, 0.770726836,
        0.78534837, -5.91048566, -4.8288477, -1.35560162, -3.60938315, 1.15812301, 2.44299541, 1.3611519,
        -5.40950935, 7.08292127, 0.27720591, -0.160210828, 2.75862348, -1.57403782, 9.97207524, -2.08957576,
        8.70299964, -5.33004663, 4.1547783, 3.51580675, -5.10788085, 4.37938353, -3.73449894, 1.44673271,
        0.51941469, 0.852232446, -1.1134965, -1.43972745, -1.62952127, -2.50759973, 1.19012213, 0.572772282,
        -2.71833059, -6.8471899, 4.2621535, 1.58954734, -0.53827818, -0.144624396, 7.63866979, 0.410423977,
        -2.4785678, -5.02681867, -2.03469811, 0.959505727, 2.68589705, 3.20889444, -10.76452533, -3.84771551,
        2.49189796, -3.19895938, -3.49948794, -2.6723897, 5.11386526, -3.85957031, -1.40741978, 0.176663166,
        -11.7111276, 0.639997364, -1.30321198, 3.20767633, 1.65750671, -11.6187257, -4.36634782, -3.18675281,
        -4.89279155, -4.08760307, 2.19269283, -1.5892487, 0.17948212, 4.81376107, 2.01871001, -0.324211095,
        -0.2790092, 1.12603878, -3.61503491, -2.86982317, -3.03634532, 8.0771391, 2.21302089, 2.91496011,
        -2.58564072, 0.0, 0.0
    };

    std::vector<std::vector<float>> expected_results_vectors{
        expected_result_zeros, expected_result_ones, expected_result_random};

    /* Every input needs a matching expected vector that is at least as long
     * as the input: the inverse test below reads up to index fftLen - 1. */
    REQUIRE(expected_results_vectors.size() == input_vectors.size());
    for (size_t j = 0; j < input_vectors.size(); j++) {
        REQUIRE(expected_results_vectors[j].size() >= input_vectors[j].size());
    }

    arm::app::math::FftInstance fftInstance;

    /* Iterate over each of the input vectors, calculate FFT and compare with
     * corresponding expected_results vectors. */
    for (size_t j = 0; j < input_vectors.size(); j++) {
        const size_t fftLen = input_vectors[j].size();
        arm::app::math::MathUtils::FftInitF32(static_cast<uint16_t>(fftLen), fftInstance);

        /* Transform a copy so that input_vectors[j] is guaranteed to still hold
         * the original signal for the inverse check below.
         * NOTE(review): this guards against FftF32 using its input buffer as
         * scratch space (e.g. if backed by CMSIS-DSP's real FFT) — confirm. */
        std::vector<float> fftInput = input_vectors[j];
        arm::app::math::MathUtils::FftF32(fftInput, output_vectors[j], fftInstance);

        /* Absolute margin of 1e-3 (same value as the former spelling "10e-4"). */
        const float tolerance = 1e-3f;

        /* NOTE(review): only the first fftLen/2 values of each output are compared. */
        for (size_t i = 0; i < fftLen / 2; i++) {
            CHECK(output_vectors[j][i] == Approx(expected_results_vectors[j][i]).margin(tolerance));
        }
    }

    /* Test inverse FFTs using the forward FFT for complex numbers.
     * IFFT(XVec) = (1/N)(Conj(FFT(Conj(Xvec)))) */
    for (size_t j = 0; j < input_vectors.size(); j++) {
        const size_t fftLen = input_vectors[j].size();
        const size_t inputSz = fftLen * 2;
        const std::vector<float>& expectedSpectrum = expected_results_vectors[j];

        /* This vector will populate the input for FFT for complex numbers. */
        std::vector<float> inputWithConjugates(inputSz);

        /* We expect the output of this test to return the original input. */
        const std::vector<float>& expectedOutputVector = input_vectors[j];

        /* Placeholder for output vector. */
        std::vector<float> outputVector(inputWithConjugates.size());

        /* Populate the 0 and N/2 elements (these are real numbers only - no
         * imaginary parts, which stay at the zero the vector was created with). */
        inputWithConjugates[0] = expectedSpectrum[0];
        inputWithConjugates[fftLen] = expectedSpectrum[1];

        /* Populate the rest of the elements - conjugates of the original for the
         * left mirror, and the right side with what the left would have been,
         * i.e., conjugate(conjugate(X_left)). */
        for (size_t i = 2; i < fftLen; i += 2) {
            inputWithConjugates[i] = expectedSpectrum[i];
            inputWithConjugates[i + 1] = -expectedSpectrum[i + 1];
            inputWithConjugates[fftLen + i] = expectedSpectrum[fftLen - i];
            inputWithConjugates[fftLen + i + 1] = expectedSpectrum[fftLen - i + 1];
        }

        arm::app::math::MathUtils::FftInitF32(
            static_cast<uint16_t>(fftLen), fftInstance, arm::app::math::FftType::complex);

        arm::app::math::MathUtils::FftF32(inputWithConjugates, outputVector, fftInstance);

        const float tolerance = 0.1f;

        /* The unnormalised transform returns fftLen times the signal. */
        const float scale = static_cast<float>(fftLen);

        for (size_t i = 0; i < expectedOutputVector.size(); i++) {
            CHECK(outputVector[i * 2] / scale == Approx(expectedOutputVector[i]).margin(tolerance));

            /* The imaginary part here should be close to 0 as the original input
             * we supplied was real. */
            CHECK(outputVector[i * 2 + 1] / scale == Approx(0.f).margin(tolerance));
        }
    }
}

TEST_CASE("Test VecLogarithmF32")
{
    /* Test constants: */
    std::vector<float> input = {0.1e-10, 0.5, 1, M_PI, M_E};
    const std::vector<float> expectedResult = {-25.328436, -0.693147181, 0, 1.144729886, 1};
    std::vector<float> output(input.size());

    arm::app::math::MathUtils::VecLogarithmF32(input, output);

    for (size_t i = 0; i < input.size(); i++) {
        CHECK(expectedResult[i] == Approx(output[i]));
    }
}

TEST_CASE("Test DotProductF32")
{
    /* Test constants: orthogonal vectors, so the dot product is exactly 0. */
    std::vector<float> inputA{1, 1, 1, 0, 0, 0};
    std::vector<float> inputB{0, 0, 0, 1, 1, 1};
    const auto len = static_cast<uint32_t>(inputA.size());

    const float expectedResult = 0;
    const float dot_prod = arm::app::math::MathUtils::DotProductF32(inputA.data(), inputB.data(), len);
    CHECK(dot_prod == expectedResult);
}

TEST_CASE("Test ComplexMagnitudeSquaredF32")
{
    /* Test constants: interleaved (real, imaginary) pairs. */
    std::vector<float> input{0.0, 0.0, 0.5, 0.5, 1, 1};
    const size_t inputLen = input.size();
    const std::vector<float> expectedResult{0.0, 0.5, 2};
    const size_t outputLen = inputLen / 2;
    std::vector<float> output(outputLen);

    /* Pass pointers to input/output vectors as this function over-writes the first half
     * of the input vector with output results */
    arm::app::math::MathUtils::ComplexMagnitudeSquaredF32(input.data(), inputLen, output.data(), outputLen);

    for (size_t i = 0; i < outputLen; i++) {
        CHECK(expectedResult[i] == Approx(output[i]));
    }
}

/**
 * @brief Simple function to test the Softmax function
 *
 * @param input         Input vector
 * @param goldenOutput  Expected output vector
 */
static void TestSoftmaxF32(
    const std::vector<float>& input,
    const std::vector<float>& goldenOutput)
{
    std::vector<float> output = input; /* Function modifies the vector in-place */
    arm::app::math::MathUtils::SoftmaxF32(output);

    /* Check the sizes first so the element loop can never index past the end
     * of `output`. */
    REQUIRE(output.size() == goldenOutput.size());

    for (size_t i = 0; i < goldenOutput.size(); ++i) {
        CHECK(goldenOutput[i] == Approx(output[i]));
    }
}

TEST_CASE("Test SoftmaxF32")
{
    SECTION("Simple series") {
        const std::vector<float> input {
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        };

        const std::vector<float> expectedOutput {
            7.80134161e-05, 2.12062451e-04, 5.76445508e-04,
            1.56694135e-03, 4.25938820e-03, 1.15782175e-02,
            3.14728583e-02, 8.55520989e-02, 2.32554716e-01,
            6.32149258e-01
        };

        TestSoftmaxF32(input, expectedOutput);
    }

    SECTION("Random series") {
        const std::vector<float> input {
            0.8810943246170809, 0.5877587675947015, 0.6841546454788743,
            0.4155920960071594, 0.9799415323651671, 0.5066432973545711,
            0.3846024252355448, 0.4568689569632123, 0.3284413744557605,
            0.49152323726213554
        };

        const std::vector<float> expectedOutput {
            0.13329595, 0.09940837, 0.10946799, 0.08368583,
            0.14714509, 0.09166319, 0.08113220, 0.08721240,
            0.07670132, 0.09028766
        };

        TestSoftmaxF32(input, expectedOutput);
    }

    SECTION("Series with large STD") {
        const std::vector<float> input {
            0.001, 1000.000
        };

        const std::vector<float> expectedOutput {
            0.000, 1.000
        };

        TestSoftmaxF32(input, expectedOutput);
    }
}