aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/NetworkQuantizer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/NetworkQuantizer.hpp')
-rw-r--r--src/armnn/NetworkQuantizer.hpp27
1 files changed, 25 insertions, 2 deletions
diff --git a/src/armnn/NetworkQuantizer.hpp b/src/armnn/NetworkQuantizer.hpp
index 4f6359f36d..d384bdc545 100644
--- a/src/armnn/NetworkQuantizer.hpp
+++ b/src/armnn/NetworkQuantizer.hpp
@@ -6,9 +6,12 @@
#pragma once
#include <armnn/INetwork.hpp>
-#include <armnn/INetworkQuantizer.hpp>
+#include <armnnQuantizer/INetworkQuantizer.hpp>
+#include <armnn/IRuntime.hpp>
#include <armnn/Types.hpp>
+#include <armnn/Optional.hpp>
+#include "DynamicQuantizationVisitor.hpp"
#include "RangeTracker.hpp"
namespace armnn
@@ -18,21 +21,41 @@ class NetworkQuantizer : public INetworkQuantizer
{
public:
NetworkQuantizer(INetwork* inputNetwork, const QuantizerOptions& options)
- : m_InputNetwork(inputNetwork), m_Options(options) {}
+ : m_InputNetwork(inputNetwork),
+ m_NetworkId(0),
+ m_Runtime(nullptr, &IRuntime::Destroy),
+ m_RefineCount(0),
+ m_Options(options) {}
void OverrideInputRange(LayerBindingId layerId, float min, float max) override;
void Refine(const InputTensors& inputTensors) override;
+
+ // Required for testing? Need some way to get min/max in RangeTracker (m_Ranges)
+ std::pair<float, float> GetMinMaxRange(LayerGuid guid, unsigned int idx) { return m_Ranges.GetRange(guid, idx); }
INetworkPtr ExportNetwork() override;
private:
/// Original input network to quantize
INetwork* m_InputNetwork;
+ NetworkId m_NetworkId;
+
+ // if we are run in dynamic mode this unique pointer will hold
+ // the runtime between invocations of the Refine method.
+ IRuntimePtr m_Runtime;
+
+ Optional<DynamicQuantizationVisitor> m_DynamicQuantizationVisitor;
+
+ // counts the number of times refine is called
+ unsigned int m_RefineCount;
+
/// Mapping from Guid to an array of ranges for outputs
RangeTracker m_Ranges;
/// Options for the NetworkQuantizer
QuantizerOptions m_Options;
+
+ std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle);
};
} //namespace armnn