diff options
author | alexander <alexander.efremov@arm.com> | 2021-03-26 21:42:19 +0000 |
---|---|---|
committer | Kshitij Sisodia <kshitij.sisodia@arm.com> | 2021-03-29 16:29:55 +0100 |
commit | 3c79893217bc632c9b0efa815091bef3c779490c (patch) | |
tree | ad06b444557eb8124652b45621d736fa1b92f65d /tests/common | |
parent | 6ad6d55715928de72979b04194da1bdf04a4c51b (diff) | |
download | ml-embedded-evaluation-kit-3c79893217bc632c9b0efa815091bef3c779490c.tar.gz |
Opensource ML embedded evaluation kit21.03
Change-Id: I12e807f19f5cacad7cef82572b6dd48252fd61fd
Diffstat (limited to 'tests/common')
-rw-r--r-- | tests/common/AppContextTest.cc | 63 | ||||
-rw-r--r-- | tests/common/ClassifierTests.cc | 76 | ||||
-rw-r--r-- | tests/common/ProfilerTests.cc | 61 | ||||
-rw-r--r-- | tests/common/SlidingWindowTests.cc | 266 |
4 files changed, 466 insertions, 0 deletions
diff --git a/tests/common/AppContextTest.cc b/tests/common/AppContextTest.cc new file mode 100644 index 0000000..42b142d --- /dev/null +++ b/tests/common/AppContextTest.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AppContext.hpp" + +#include <catch.hpp> + +TEST_CASE("Common: Application context") +{ + SECTION("Add primitive type Parameter") + { + arm::app::ApplicationContext context; + context.Set<uint32_t>("imgIndex", 0); + auto data = context.Get<uint32_t>("imgIndex"); + + REQUIRE(0 == data); + + } + + SECTION("Add object parameter") + { + arm::app::ApplicationContext context; + std::vector <std::string> vect{"a"}; + context.Set<std::vector <std::string>>("vect", vect); + auto data = context.Get<std::vector <std::string>>("vect"); + + REQUIRE(vect == data); + } + + SECTION("Add reference object parameter") + { + arm::app::ApplicationContext context; + std::vector <std::string> vect{"a"}; + context.Set<std::vector <std::string>&>("vect", vect); + auto data = context.Get<std::vector <std::string>&>("vect"); + + REQUIRE(vect == data); + } + + SECTION("Add object pointer parameter") + { + arm::app::ApplicationContext context; + std::vector <std::string>* vect = new std::vector <std::string>{"a"}; + context.Set<std::vector <std::string>*>("vect", vect); + auto data = context.Get<std::vector <std::string>*>("vect"); + + REQUIRE(vect == data); + delete(vect); + } +}
\ No newline at end of file diff --git a/tests/common/ClassifierTests.cc b/tests/common/ClassifierTests.cc new file mode 100644 index 0000000..f08a09a --- /dev/null +++ b/tests/common/ClassifierTests.cc @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Classifier.hpp" + +#include <catch.hpp> + +TEST_CASE("Common classifier") +{ + SECTION("Test invalid classifier") + { + TfLiteTensor* outputTens = nullptr; + std::vector <arm::app::ClassificationResult> resultVec; + arm::app::Classifier classifier; + REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5)); + } + + SECTION("Test valid classifier UINT8") + { + const int dimArray[] = {1, 1001}; + std::vector <std::string> labels(1001); + std::vector <uint8_t> outputVec(1001); + TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray); + TfLiteTensor tfTensor = tflite::testing::CreateQuantizedTensor( + outputVec.data(), dims, 1, 0, "test"); + TfLiteTensor* outputTensor = &tfTensor; + std::vector <arm::app::ClassificationResult> resultVec; + arm::app::Classifier classifier; + REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 5)); + REQUIRE(5 == resultVec.size()); + } + + SECTION("Get classification results") + { + const int dimArray[] = {1, 1001}; + std::vector <std::string> labels(1001); + std::vector<uint8_t> outputVec(1001, static_cast<uint8_t>(5)); + TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray); + TfLiteTensor tfTensor = tflite::testing::CreateQuantizedTensor( + outputVec.data(), dims, 1, 0, "test"); + TfLiteTensor* outputTensor = &tfTensor; + + std::vector <arm::app::ClassificationResult> resultVec; + + /* Set the top five results. */ + std::vector<std::pair<uint32_t, uint8_t>> selectedResults { + {0, 8}, {20, 7}, {10, 7}, {15, 9}, {1000, 10}}; + + for (size_t i = 0; i < selectedResults.size(); ++i) { + outputVec[selectedResults[i].first] = selectedResults[i].second; + } + + arm::app::Classifier classifier; + REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 5)); + REQUIRE(5 == resultVec.size()); + + REQUIRE(resultVec[0].m_labelIdx == 1000); + REQUIRE(resultVec[1].m_labelIdx == 15); + REQUIRE(resultVec[2].m_labelIdx == 0); + REQUIRE(resultVec[3].m_labelIdx == 20); + REQUIRE(resultVec[4].m_labelIdx == 10); + } +} diff --git a/tests/common/ProfilerTests.cc b/tests/common/ProfilerTests.cc new file mode 100644 index 0000000..caf492b --- /dev/null +++ b/tests/common/ProfilerTests.cc @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Profiler.hpp" + +#include "AppContext.hpp" +#include "TensorFlowLiteMicro.hpp" + +#include <catch.hpp> +#include <iostream> + + +TEST_CASE("Common: Test Profiler") +{ + hal_platform platform; + data_acq_module data_acq {}; + data_psn_module data_psn {}; + platform_timer timer {}; + + /* Initialise the HAL and platform. */ + hal_init(&platform, &data_acq, &data_psn, &timer); + hal_platform_init(&platform); + + /* An invalid profiler shouldn't be of much use. */ + arm::app::Profiler profilerInvalid {nullptr, "test_invalid"}; + REQUIRE(false == profilerInvalid.StartProfiling()); + REQUIRE(false == profilerInvalid.StopProfiling()); + + arm::app::Profiler profilerValid{&platform, "test_valid"}; + REQUIRE(true == profilerValid.StartProfiling()); + REQUIRE(true == profilerValid.StopProfiling()); + + std::string strProfile = profilerValid.GetResultsAndReset(); + REQUIRE(std::string::npos != strProfile.find("test_valid")); + +#if defined(CPU_PROFILE_ENABLED) + /* We should have milliseconds elapsed. */ + REQUIRE(std::string::npos != strProfile.find("ms")); +#endif /* defined(CPU_PROFILE_ENABLED) */ + + /* Abuse should fail: */ + REQUIRE(false == profilerValid.StopProfiling()); /* We need to start it first. */ + REQUIRE(true == profilerValid.StartProfiling()); /* Should be able to start it fine. */ + REQUIRE(false == profilerValid.StartProfiling()); /* Can't restart it without resetting. */ + profilerValid.Reset(); /* Reset. */ + REQUIRE(true == profilerValid.StartProfiling()); /* Can start it again now. */ + REQUIRE(true == profilerValid.StopProfiling()); /* Can start it again now. */ +} diff --git a/tests/common/SlidingWindowTests.cc b/tests/common/SlidingWindowTests.cc new file mode 100644 index 0000000..bfdb5b7 --- /dev/null +++ b/tests/common/SlidingWindowTests.cc @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AudioUtils.hpp" +#include "catch.hpp" + +TEST_CASE("Common: Slide long data") +{ + std::vector<int> test{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + + SECTION("Fit the data") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 1); + + for (int i = 0 ; i < 10; ++i) { + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == i + 1); + } + + REQUIRE(!slider.HasNext()); + REQUIRE(nullptr == slider.Next()); + } + + SECTION("Fit the data stride> window") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 2, 3); + + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == 1); + + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == 4); + + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == 7); + + REQUIRE(!slider.HasNext()); + REQUIRE(nullptr == slider.Next()); + } + + SECTION("Fit the data stride < window") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 5, 1); + + for (int i = 0 ; i < 6; i++) { + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == i + 1); + } + + REQUIRE(!slider.HasNext()); + REQUIRE(nullptr == slider.Next()); + } +} + + +TEST_CASE("Common: Slide data size 1") +{ + std::vector<int> test{1}; + + SECTION("Fit the data") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 1); + + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == 1); + REQUIRE(!slider.HasNext()); + REQUIRE(nullptr == slider.Next()); + } + + SECTION("Does not Fit the data because of big window") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 2, 1); + + REQUIRE(!slider.HasNext()); + REQUIRE(nullptr == slider.Next()); + } + + SECTION("Does not Fit the data because of big stride") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 2); + + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == 1); + REQUIRE(!slider.HasNext()); + REQUIRE(nullptr == slider.Next()); + } + +} + + +TEST_CASE("Common: Slide reset") +{ + SECTION("current range") + { + std::vector<int> test{1}; + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 1); + int *saved = slider.Next(); + slider.Reset(); + + REQUIRE(slider.Next() == saved); + } + + SECTION("new range") + { + std::vector<int> test{1}; + std::vector<int> test2{100}; + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 1); + slider.Next(); + slider.Reset(test2.data()); + + REQUIRE(*slider.Next() == 100); + } +} + + +TEST_CASE("Common: Slide fast forward") +{ + std::vector<int> test{1, 2, 3, 4, 5}; + + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 1); + SECTION("at the beginning") { + slider.FastForward(3); + + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == 4); + } + + SECTION("in the middle") + { + slider.Next(); + slider.FastForward(3); + + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == 4); + } + + SECTION("at the end") + { + while(slider.HasNext()) { + slider.Next(); + } + slider.FastForward(3); + + REQUIRE(slider.HasNext()); + REQUIRE(*slider.Next() == 4); + } + + SECTION("out of the range") + { + slider.FastForward(100); + + REQUIRE(!slider.HasNext()); + REQUIRE(slider.Next() == nullptr); + } +} + + +TEST_CASE("Common: Slide Index") +{ + std::vector<int> test{1, 2, 3, 4, 5}; + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 1); + REQUIRE(slider.Index() == 0); + for (int i = 0; i < 5; i++) { + slider.Next(); + REQUIRE(slider.Index() == i); + } +} + + +TEST_CASE("Common: Total strides") +{ + std::vector<int> test{1, 2, 3, 4, 5}; + + SECTION("Element by element") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 1); + REQUIRE(slider.TotalStrides() == 4 ); + } + + SECTION("Step through element") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 1, 2); + REQUIRE(slider.TotalStrides() == 2 ); + } + + SECTION("Window = data") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 5, 2); + REQUIRE(slider.TotalStrides() == 0 ); + } + + SECTION("Window > data") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 6, 2); + REQUIRE(slider.TotalStrides() == 0 ); + } + + SECTION("Window < data, not enough for the next stride") + { + auto slider = arm::app::audio::SlidingWindow<int>(test.data(), test.size(), 4, 2); + REQUIRE(slider.TotalStrides() == 0 ); + } +} + + +TEST_CASE("Common: Next window data index") +{ + std::vector<int> test{1, 2, 3, 4, 5}; + + /* Check we get the correct index returned */ + SECTION("Stride 1") + { + auto slider = arm::app::audio::ASRSlidingWindow<int>(test.data(), test.size(), 1, 1); + REQUIRE(slider.NextWindowStartIndex() == 0); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 1); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 2); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 3); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 4); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 5); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 5); + } + + SECTION("Stride 2") + { + auto slider = arm::app::audio::ASRSlidingWindow<int>(test.data(), test.size(), 1, 2); + REQUIRE(slider.NextWindowStartIndex() == 0); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 2); + REQUIRE(slider.NextWindowStartIndex() == 2); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 4); + } + + SECTION("Stride 3") + { + auto slider = arm::app::audio::ASRSlidingWindow<int>(test.data(), test.size(), 1, 3); + REQUIRE(slider.NextWindowStartIndex() == 0); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 3); + REQUIRE(slider.NextWindowStartIndex() == 3); + slider.Next(); + REQUIRE(slider.NextWindowStartIndex() == 6); + REQUIRE(!slider.HasNext()); + REQUIRE(slider.Next() == nullptr); + REQUIRE(slider.NextWindowStartIndex() == 6); + } +} |