/*
 * SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "Classifier.hpp"

#include <catch.hpp>

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

namespace {

/** Number of entries in the synthetic classifier output tensor (shape 1 x kNumClasses). */
constexpr int kNumClasses = 1001;

/**
 * @brief   Builds a synthetic 1 x kNumClasses output tensor, plants the given scores in it
 *          and checks that the classifier reports those indices in the listed order.
 *
 * @tparam  T                   Element type of the output tensor (the tests use uint8_t,
 *                              int8_t and float).
 * @param   selectedResults     (class index, score) pairs to write into the tensor, listed in
 *                              the order the classifier is expected to return them. The number
 *                              of pairs is the top-N count requested from the classifier.
 * @param   defaultTensorValue  Score given to every other tensor element; it should be lower
 *                              than every selected score so the selected entries rank highest.
 */
template<typename T>
void test_classifier_result(const std::vector<std::pair<uint32_t, T>>& selectedResults,
                            T defaultTensorValue)
{
    int dimArray[] = {1, kNumClasses};
    const std::vector<std::string> labels(kNumClasses);
    std::vector<T> outputVec(kNumClasses, defaultTensorValue);

    TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(dimArray);
    TfLiteTensor tfTensor = tflite::testing::CreateQuantizedTensor(outputVec.data(), dims, 1, 0);
    TfLiteTensor* outputTensor = &tfTensor;

    /* The tensor was created over outputVec.data(); the expectations below rely on it reading
     * that buffer directly (no copy), so these writes are what the classifier sees. */
    for (const auto& selectedResult : selectedResults) {
        /* Guard against an out-of-range index in the test data. */
        REQUIRE(selectedResult.first < outputVec.size());
        outputVec[selectedResult.first] = selectedResult.second;
    }

    /* Ask for exactly as many results as were planted, so each returned entry has an
     * expectation to compare against (previously hard-coded to 5). */
    const auto topN = static_cast<uint32_t>(selectedResults.size());
    std::vector<arm::app::ClassificationResult> resultVec;

    arm::app::Classifier classifier;
    REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, topN, true));
    REQUIRE(resultVec.size() == selectedResults.size());

    for (size_t i = 0; i < resultVec.size(); ++i) {
        REQUIRE(resultVec[i].m_labelIdx == selectedResults[i].first);
    }
}

} /* namespace */

TEST_CASE("Common classifier")
{
    SECTION("Test invalid classifier")
    {
        /* Note: errors or warnings emitted by this test will also appear in the output of any
         * subsequently failing test, which can be misleading. Until that is resolved, the
         * expected messages are bracketed by banners so they can be recognised. */
        std::printf("Invalid classifier common test output:\n");
        TfLiteTensor* outputTens = nullptr;
        std::vector<arm::app::ClassificationResult> resultVec;
        arm::app::Classifier classifier;
        REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5, true));
        std::printf("End of invalid classifier common test output. \nERROR messages above this line are "
                    "expected and can be ignored.\n\n");
    }

    SECTION("Test classification results")
    {
        /* In each case the last two entries share a score (or nearly so for float); the
         * expected order for exact ties relies on the classifier's tie-breaking. */
        SECTION("uint8") {
            /* Set the top five results. */
            std::vector<std::pair<uint32_t, uint8_t>> selectedResults {
                {1000, 10}, {15, 9}, {0, 8}, {20, 7}, {10, 7} };

            test_classifier_result(selectedResults, static_cast<uint8_t>(5));
        }
        SECTION("int8") {
            /* Set the top five results. */
            std::vector<std::pair<uint32_t, int8_t>> selectedResults {
                {1000, 10}, {15, 9}, {0, 8}, {20, -7}, {10, -7} };

            test_classifier_result(selectedResults, static_cast<int8_t>(-100));
        }
        SECTION("float") {
            /* Set the top five results. */
            std::vector<std::pair<uint32_t, float>> selectedResults {
                {1000, 10.9f}, {15, 9.8f}, {0, 8.7f}, {20, -7.0f}, {10, -7.1f} };

            test_classifier_result(selectedResults, -100.0f);
        }
    }
}