aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/ModelAccuracyChecker.hpp
blob: c4dd4f1b0504cd3d9d698fb9418a13a1545c0d74 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <algorithm>
#include <armnn/Types.hpp>
#include <boost/assert.hpp>
#include <boost/variant/apply_visitor.hpp>
#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace armnnUtils
{

// NOTE(review): a using-directive in a header makes every armnn name visible inside armnnUtils for all
// translation units that include this file. Left as-is because code elsewhere may rely on it.
using namespace armnn;

// Category names associated with a label.
// Stored as a vector because one label can be described by several category names.
using LabelCategoryNames = std::vector<std::string>;

/** Split a string into tokens by a delimiter
 *
 * @param[in] originalString    Original string to be split
 * @param[in] delimiter         Delimiter used to split \p originalString (defaults to a single space)
 * @param[in] includeEmptyToken If true, include empty tokens in the result (defaults to false)
 * @return A vector of tokens split from \p originalString by \p delimiter
 */
std::vector<std::string>
    SplitBy(const std::string& originalString, const std::string& delimiter = " ", bool includeEmptyToken = false);

/** Remove any leading and trailing characters that appear in \p characterSet.
 *
 * @param[in] originalString    Original string to be stripped
 * @param[in] characterSet      Set of characters to be stripped from \p originalString (defaults to a single space)
 * @return A copy of \p originalString without the leading and trailing characters found in \p characterSet
 */
std::string Strip(const std::string& originalString, const std::string& characterSet = " ");

class ModelAccuracyChecker
{
public:
    /** Constructor for a model top k accuracy checker
     *
     * @param[in] validationLabelSet Mapping from names of images to be validated, to category names of their
                                     corresponding ground-truth labels.
     * @param[in] modelOutputLabels  Mapping from output nodes to the category names of their corresponding labels
                                     Note that an output node can have multiple category names.
     */
    ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabelSet,
                         const std::vector<LabelCategoryNames>& modelOutputLabels);

    /** Get Top K accuracy
     *
     * @param[in] k The number of top predictions to use for validating the ground-truth label. For example, if \p k is
                    3, then a prediction is considered correct as long as the ground-truth appears in the top 3
                    predictions.
     * @return  The accuracy, according to the top \p k th predictions.
     */
    float GetAccuracy(unsigned int k);

    /** Record the prediction result of an image
     *
     * @param[in] imageName     Name of the image.
     * @param[in] outputTensor  Output tensor of the network running \p imageName.
     */
    template <typename TContainer>
    void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor)
    {
        // Increment the total number of images processed
        ++m_ImagesProcessed;

        std::map<int, float> confidenceMap;
        auto& output = outputTensor[0];

        // Create a map of all predictions
        boost::apply_visitor([&confidenceMap](auto && value)
                             {
                                 int index = 0;
                                 for (const auto & o : value)
                                 {
                                     if (o > 0)
                                     {
                                         confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o)));
                                     }
                                     ++index;
                                 }
                             },
                             output);

        // Create a comparator for sorting the map in order of highest probability
        typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator;

        Comparator compFunctor =
            [](std::pair<int, float> element1, std::pair<int, float> element2)
            {
                return element1.second > element2.second;
            };

        // Do the sorting and store in an ordered set
        std::set<std::pair<int, float>, Comparator> setOfPredictions(
            confidenceMap.begin(), confidenceMap.end(), compFunctor);

        const std::string correctLabel = m_GroundTruthLabelSet.at(imageName);

        unsigned int index = 1;
        for (std::pair<int, float> element : setOfPredictions)
        {
            if (index >= m_TopK.size())
            {
                break;
            }
            // Check if the ground truth label value is included in the topi prediction.
            // Note that a prediction can have multiple prediction labels.
            const LabelCategoryNames predictionLabels = m_ModelOutputLabels[static_cast<size_t>(element.first)];
            if (std::find(predictionLabels.begin(), predictionLabels.end(), correctLabel) != predictionLabels.end())
            {
                ++m_TopK[index];
                break;
            }
            ++index;
        }
    }

private:
    const std::map<std::string, std::string> m_GroundTruthLabelSet;
    const std::vector<LabelCategoryNames> m_ModelOutputLabels;
    std::vector<unsigned int> m_TopK = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
    unsigned int m_ImagesProcessed   = 0;
};
} //namespace armnnUtils