diff options
Diffstat (limited to 'src/armnn/NetworkQuantizer.cpp')
-rw-r--r-- | src/armnn/NetworkQuantizer.cpp | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp index f8e5ed23a7..ccbc501618 100644 --- a/src/armnn/NetworkQuantizer.cpp +++ b/src/armnn/NetworkQuantizer.cpp @@ -12,11 +12,12 @@ #include "Layer.hpp" #include "Network.hpp" #include "NetworkQuantizer.hpp" +#include "NetworkQuantizerUtils.hpp" #include "StaticRangeVisitor.hpp" #include "QuantizerVisitor.hpp" +#include "OverrideInputRangeVisitor.hpp" -#include <map> #include <vector> #include <cmath> @@ -38,26 +39,29 @@ void INetworkQuantizer::Destroy(INetworkQuantizer *quantizer) delete boost::polymorphic_downcast<NetworkQuantizer*>(quantizer); } +void NetworkQuantizer::OverrideInputRange(LayerBindingId layerId, float min, float max) +{ + const Graph& graph = boost::polymorphic_downcast<const Network*>(m_InputNetwork)->GetGraph(); + auto inputLayers = graph.GetInputLayers(); + + // Walk the input layers of the graph and override the quantization parameters of the one with the given id + OverrideInputRangeVisitor overrideInputRangeVisitor(m_GuidToRangesMap, layerId, MinMaxRange{min, max}); + VisitLayers(inputLayers, overrideInputRangeVisitor); +} + INetworkPtr NetworkQuantizer::ExportNetwork() { const Graph& graph = boost::polymorphic_downcast<const Network*>(m_InputNetwork)->GetGraph().TopologicalSort(); - auto VisitLayers = [&graph](ILayerVisitor& visitor) - { - for (auto layer : graph) - { - layer->Accept(visitor); - } - }; // Step 1) Walk the graph and register min/max values for intermediate tensors - StaticRangeVisitor rangeVisitor; - VisitLayers(rangeVisitor); + StaticRangeVisitor rangeVisitor(m_GuidToRangesMap); + VisitLayers(graph, rangeVisitor); // Step 2) Convert input InputNetwork to Quantized InputNetwork QuantizerVisitor quantizerVisitor(&rangeVisitor); - VisitLayers(quantizerVisitor); + VisitLayers(graph, quantizerVisitor); return quantizerVisitor.RetrieveFinalNetwork(); } -} //namespace armn
\ No newline at end of file +} //namespace armn |