aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/DynamicQuantizationStrategy.hpp
blob: aa77a4b56300187b0fda4d60ac808f965997bbdc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "armnn/LayerVisitorBase.hpp"
#include "RangeTracker.hpp"
#include "layers/DebugLayer.hpp"

#include <armnn/INetwork.hpp>
#include <armnnQuantizer/INetworkQuantizer.hpp>

namespace armnn
{

/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
class DynamicQuantizationStrategy : public armnn::IStrategy
{
public:

    DynamicQuantizationStrategy(RangeTracker& rangeTracker, Graph& graph);
    ~DynamicQuantizationStrategy() = default;

    virtual void ExecuteStrategy(const armnn::IConnectableLayer* layer,
                                 const armnn::BaseDescriptor& descriptor,
                                 const std::vector<armnn::ConstTensor>& constants,
                                 const char* name,
                                 const armnn::LayerBindingId id = 0) override;

    const std::vector<armnn::LayerBindingId>& GetOutputLayers();
    void VisitNonCalibratedLayers();
    void FinishStrategy() override;


private:
    /// Set the range for an output slot on a layer
    void SetRange(const IConnectableLayer* layer, unsigned int outputIdx, float min, float max);

    void ForwardParentParameters(const IConnectableLayer* layer);

    /// Mapping from a layer Guid to an array of ranges for outputs
    RangeTracker& m_RangeTracker;

    Graph& m_Graph;

    std::vector<const IConnectableLayer*> m_LayersToCalibrate;
    std::vector<const IConnectableLayer*> m_LayersNotToCalibrate;
    std::vector<DebugLayer*> m_DebugLayers;

    std::vector<armnn::LayerBindingId> m_OutputLayers;
    void AddToCalibratedLayers(const IConnectableLayer* layer);
    void AddToNonCalibratedLayers(const IConnectableLayer* layer);
    void RemoveDebugLayers();


};
} //namespace armnn