diff options
author | Finn Williams <Finn.Williams@arm.com> | 2021-02-09 15:56:23 +0000 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2021-02-12 13:10:20 +0000 |
commit | b454c5c65efb238c130b042ace390b2bc7f0bf75 (patch) | |
tree | d6681d0abf416b3cc280bc3bb70e7d55dfd40a0d /src/armnn/NetworkQuantizer.cpp | |
parent | 8eae955f665f371b0a2c7c1a06e8ba442afa2298 (diff) | |
download | armnn-b454c5c65efb238c130b042ace390b2bc7f0bf75.tar.gz |
IVGCVSW-4893 Refactor ILayerVisitor using unified interface strategy.
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: Id7bc8255a8e3f9e5aac65d510bec8a559bf37246
Diffstat (limited to 'src/armnn/NetworkQuantizer.cpp')
-rw-r--r-- | src/armnn/NetworkQuantizer.cpp | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp index e6becee96f..eed3f41bdc 100644 --- a/src/armnn/NetworkQuantizer.cpp +++ b/src/armnn/NetworkQuantizer.cpp @@ -8,9 +8,9 @@ #include "Graph.hpp" #include "Layer.hpp" #include "Network.hpp" -#include "DynamicQuantizationVisitor.hpp" -#include "StaticRangeVisitor.hpp" -#include "QuantizerVisitor.hpp" +#include "DynamicQuantizationStrategy.hpp" +#include "StaticRangeStrategy.hpp" +#include "QuantizerStrategy.hpp" #include "OverrideInputRangeVisitor.hpp" #include <TensorIOUtils.hpp> @@ -60,9 +60,9 @@ void NetworkQuantizer::OverrideInputRange(LayerBindingId layerId, float min, flo void NetworkQuantizer::Refine(const InputTensors& inputTensors) { - // The first time Refine is called the m_Runtime and the DynamicQuantizationVisitor + // The first time Refine is called the m_Runtime and the DynamicQuantizationStrategy // will not have been created. Need to get the environment set up, Runtime loaded, - // DynamicQuantizationVisitor created and run over the network to initialise itself + // DynamicQuantizationStrategy created and run over the network to initialise itself // and the RangeTracker the Debug callback registered and an initial inference // done to set up the first min/max values if (!m_Runtime) @@ -71,15 +71,15 @@ void NetworkQuantizer::Refine(const InputTensors& inputTensors) m_Ranges.SetDynamicMode(true); const Graph& cGraph = PolymorphicDowncast<const Network*>(m_InputNetwork)->GetGraph().TopologicalSort(); - // need to insert Debug layers in the DynamicQuantizationVisitor + // need to insert Debug layers in the DynamicQuantizationStrategy Graph& graph = const_cast<Graph&>(cGraph); // Initialize RangeTracker to the default values for each layer. // The default values are overwritten by the min/max that is // recorded during the first dataset min/max calibration. This // initialisation is only required for the first call of Refine(). - m_DynamicQuantizationVisitor = DynamicQuantizationVisitor(m_Ranges, graph); - VisitLayers(cGraph, m_DynamicQuantizationVisitor.value()); + m_DynamicQuantizationStrategy = DynamicQuantizationStrategy(m_Ranges, graph); + ApplyStrategyToLayers(cGraph, m_DynamicQuantizationStrategy.value()); IRuntime::CreationOptions options; m_Runtime = IRuntime::Create(options); @@ -119,7 +119,7 @@ void NetworkQuantizer::Refine(const InputTensors& inputTensors) // Create output tensor for EnqueueWorkload std::vector<armnn::BindingPointInfo> outputBindings; - auto outputLayers = m_DynamicQuantizationVisitor.value().GetOutputLayers(); + auto outputLayers = m_DynamicQuantizationStrategy.value().GetOutputLayers(); std::vector<TContainer> outputVectors; for (auto outputLayerBindingId : outputLayers) { @@ -144,16 +144,16 @@ INetworkPtr NetworkQuantizer::ExportNetwork() if (!m_Runtime) { m_Ranges.SetDynamicMode(false); - StaticRangeVisitor rangeVisitor(m_Ranges); - VisitLayers(graph, rangeVisitor); + StaticRangeStrategy rangeStrategy(m_Ranges); + ApplyStrategyToLayers(graph, rangeStrategy); } else { // Set min/max range of non-calibrated layers to parent layer's range - m_DynamicQuantizationVisitor.value().VisitNonCalibratedLayers(); + m_DynamicQuantizationStrategy.value().VisitNonCalibratedLayers(); // now tear down the runtime and the dynamic visitor. m_Runtime.reset(nullptr); - m_DynamicQuantizationVisitor = EmptyOptional(); + m_DynamicQuantizationStrategy = EmptyOptional(); m_RefineCount = 0; } @@ -177,8 +177,8 @@ INetworkPtr NetworkQuantizer::ExportNetwork() throw InvalidArgumentException("Unsupported quantization target"); } - QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get(), m_Options.m_PreserveType); - VisitLayers(graph, quantizerVisitor); + QuantizerStrategy quantizerVisitor(m_Ranges, quantizationScheme.get(), m_Options.m_PreserveType); + ApplyStrategyToLayers(graph, quantizerVisitor); // clear the ranges m_Ranges.Reset(); |