aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/RangeTracker.cpp
blob: ae756fbb9cfe86fbee1f14a6a0783f08772ce9b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RangeTracker.hpp"
#include "InternalTypes.hpp"

namespace armnn
{

// Stores the (min, max) range for one output slot of the given layer.
//
// The ranges are kept per layer GUID; an entry for the layer is created on first use and grown
// so that it holds at least one range per output slot reported by the layer.
//
// layer      Layer whose range is recorded. It is dereferenced, so it must not be null.
// outputIdx  Index of the output slot whose range is set.
// min, max   Bounds stored for that slot.
//
// Throws std::out_of_range if outputIdx does not address an entry of the (resized) range list.
void RangeTracker::SetRange(const armnn::IConnectableLayer* layer, unsigned int outputIdx, float min, float max)
{
    auto& ranges = m_GuidToRangesMap[layer->GetGuid()];

    unsigned int numOfOutputSlots = layer->GetNumOutputSlots();
    // Output layers are a special case: when no output slots are reported,
    // one range entry is still kept for the layer.
    if (numOfOutputSlots == 0)
    {
        ++numOfOutputSlots;
    }
    if (ranges.size() < numOfOutputSlots)
    {
        ranges.resize(numOfOutputSlots);
    }

    // Bounds-checked access, as in the other accessors of this class: an outputIdx beyond the
    // stored entries must not silently write past the end of the container.
    ranges.at(outputIdx) = std::make_pair(min, max);
}

// Returns the range stored for output slot idx of the layer identified by guid.
//
// If no entry exists for guid, the result depends on IsInDynamicMode(): when it returns true an
// armnn::Exception is thrown, otherwise DefaultRange() is returned.
// If an entry exists, idx is bounds-checked through at().
RangeTracker::MinMaxRange RangeTracker::GetRange(armnn::LayerGuid guid, unsigned int idx) const
{
    const auto entry = m_GuidToRangesMap.find(guid);

    // Common case first: the layer is known, hand back the requested slot.
    if (entry != m_GuidToRangesMap.end())
    {
        return entry->second.at(idx);
    }

    // Unknown layer outside dynamic mode falls back to the default range.
    if (!IsInDynamicMode())
    {
        return DefaultRange();
    }

    throw armnn::Exception("Have no entry for layer GUID [" + std::to_string(guid) + "]");
}

// Lowers the stored minimum for output slot idx of the layer identified by guid when newMin is
// smaller than the current value; otherwise the stored minimum is left unchanged.
//
// Throws armnn::Exception if there is no entry for guid.
// idx is bounds-checked through at().
void RangeTracker::RefineMin(LayerGuid guid, unsigned int idx, float newMin)
{
    auto entry = m_GuidToRangesMap.find(guid);
    if (entry == m_GuidToRangesMap.end())
    {
        // find() returned end(): dereferencing it would be undefined behaviour, so report the
        // missing layer with the same message GetRange uses.
        throw armnn::Exception("Have no entry for layer GUID [" + std::to_string(guid) + "]");
    }

    auto& currentMin = entry->second.at(idx).first;
    if (newMin < currentMin)
    {
        currentMin = newMin;
    }
}

// Raises the stored maximum for output slot idx of the layer identified by guid when newMax is
// larger than the current value; otherwise the stored maximum is left unchanged.
//
// Throws armnn::Exception if there is no entry for guid.
// idx is bounds-checked through at().
void RangeTracker::RefineMax(LayerGuid guid, unsigned int idx, float newMax)
{
    auto entry = m_GuidToRangesMap.find(guid);
    if (entry == m_GuidToRangesMap.end())
    {
        // find() returned end(): dereferencing it would be undefined behaviour, so report the
        // missing layer with the same message GetRange uses.
        throw armnn::Exception("Have no entry for layer GUID [" + std::to_string(guid) + "]");
    }

    auto& currentMax = entry->second.at(idx).second;
    if (newMax > currentMax)
    {
        currentMax = newMax;
    }
}

// Unconditionally overwrites the stored range for output slot idx of the layer identified by
// guid with (newMin, newMax).
//
// Throws armnn::Exception if there is no entry for guid.
// idx is bounds-checked through at().
void RangeTracker::ResetMinMax(LayerGuid guid, unsigned int idx, float newMin, float newMax)
{
    auto entry = m_GuidToRangesMap.find(guid);
    if (entry == m_GuidToRangesMap.end())
    {
        // find() returned end(): dereferencing it would be undefined behaviour, so report the
        // missing layer with the same message GetRange uses.
        throw armnn::Exception("Have no entry for layer GUID [" + std::to_string(guid) + "]");
    }

    // Look the slot up once and update both bounds through the same reference.
    auto& range = entry->second.at(idx);
    range.first = newMin;
    range.second = newMax;
}

void RangeTracker::Reset()
{
    m_GuidToRangesMap.clear();
}

} //namespace armnn