aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/NetworkQuantizer.hpp
blob: a07ac8827ecd1c96d5cdae7f64abcb69ea15cb02 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/INetwork.hpp>
#include <armnnQuantizer/INetworkQuantizer.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Types.hpp>
#include <armnn/Optional.hpp>

#include "DynamicQuantizationStrategy.hpp"
#include "RangeTracker.hpp"

namespace armnn
{

/// Concrete implementation of INetworkQuantizer.
///
/// Holds the network to be quantized (non-owning), the options it was created
/// with, and the per-layer output ranges gathered so far.
class NetworkQuantizer : public INetworkQuantizer
{
public:
    /// @param inputNetwork  Network to quantize. Kept as a raw pointer and never deleted by
    ///                      this class; the caller must keep it alive while the quantizer uses it.
    /// @param options       Quantizer options; a copy is stored in this object.
    NetworkQuantizer(INetwork* inputNetwork, const QuantizerOptions& options)
        : m_InputNetwork(inputNetwork)
        , m_Options(options)
    {
        // m_NetworkId, m_Runtime and m_RefineCount take their default member initializers below.
    }

    void OverrideInputRange(LayerBindingId layerId, float min, float max) override;
    void Refine(const InputTensors& inputTensors) override;
    INetworkPtr ExportNetwork() override;

    /// Returns the (min, max) range currently recorded for output slot @p idx of the
    /// layer identified by @p guid, as reported by the internal RangeTracker.
    /// NOTE(review): exposed so tests can inspect the tracked ranges.
    std::pair<float, float> GetMinMaxRange(LayerGuid guid, unsigned int idx)
    {
        return m_Ranges.GetRange(guid, idx);
    }

private:
    /// Implemented out of line; presumably computes the min/max of the tensor's data (TODO confirm).
    std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle);

    /// The network handed to the constructor (non-owning).
    INetwork* m_InputNetwork;

    /// Network id; starts at 0. Presumably identifies the network loaded into m_Runtime — confirm.
    NetworkId m_NetworkId = 0;

    /// In dynamic mode, keeps the runtime alive between successive Refine() calls.
    /// Starts out null; whatever it holds is released through IRuntime::Destroy.
    IRuntimePtr m_Runtime{nullptr, &IRuntime::Destroy};

    /// Strategy object used in dynamic mode; default-constructed here.
    Optional<DynamicQuantizationStrategy> m_DynamicQuantizationStrategy;

    /// How many times Refine() has been invoked; starts at 0.
    unsigned int m_RefineCount = 0;

    /// Output ranges keyed by layer GUID, one entry per output slot.
    RangeTracker m_Ranges;

    /// Copy of the options supplied at construction.
    QuantizerOptions m_Options;
};

} //namespace armnn