// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include #include #include #include #include #include #include #include #include namespace armnnUtils { using namespace armnn; class ModelAccuracyChecker { public: ModelAccuracyChecker(const std::map& validationLabelSet); float GetAccuracy(unsigned int k); template void AddImageResult(const std::string& imageName, std::vector outputTensor) { // Increment the total number of images processed ++m_ImagesProcessed; std::map confidenceMap; auto & output = outputTensor[0]; // Create a map of all predictions boost::apply_visitor([&](auto && value) { int index = 0; for (const auto & o : value) { if (o > 0) { confidenceMap.insert(std::pair(index, static_cast(o))); } ++index; } }, output); // Create a comparator for sorting the map in order of highest probability typedef std::function, std::pair)> Comparator; Comparator compFunctor = [](std::pair element1, std::pair element2) { return element1.second > element2.second; }; // Do the sorting and store in an ordered set std::set, Comparator> setOfPredictions( confidenceMap.begin(), confidenceMap.end(), compFunctor); std::string trimmedName = GetTrimmedImageName(imageName); int value = m_GroundTruthLabelSet.find(trimmedName)->second; unsigned int index = 1; for (std::pair element : setOfPredictions) { if(element.first == value) { ++m_TopK[index]; } else { ++index; } } } std::string GetTrimmedImageName(const std::string& imageName) const { std::string trimmedName; size_t lastindex = imageName.find_last_of("."); if(lastindex != std::string::npos) { trimmedName = imageName.substr(0, lastindex); } else { trimmedName = imageName; } return trimmedName; } private: const std::map m_GroundTruthLabelSet; std::vector m_TopK = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; unsigned int m_ImagesProcessed = 0; }; } //namespace armnnUtils