aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/ModelAccuracyChecker.cpp
blob: bee5ca2365979ab620a1560388d4b55f79d13a3c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <vector>
#include <map>
#include <boost/log/trivial.hpp>
#include "ModelAccuracyChecker.hpp"

namespace armnnUtils
{

armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, int>& validationLabels)
    : m_GroundTruthLabelSet(validationLabels){}

float ModelAccuracyChecker::GetAccuracy(unsigned int k)
{
    if(k > 10) {
        BOOST_LOG_TRIVIAL(info) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
                                   "Printing Top 10 Accuracy result!";
        k = 10;
    }
    unsigned int total = 0;
    for (unsigned int i = k; i > 0; --i)
    {
        total += m_TopK[i];
    }
    return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed);
}
}