aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/NetworkQuantizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/NetworkQuantizer.cpp')
-rw-r--r--src/armnn/NetworkQuantizer.cpp30
1 files changed, 15 insertions, 15 deletions
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp
index e6becee96f..eed3f41bdc 100644
--- a/src/armnn/NetworkQuantizer.cpp
+++ b/src/armnn/NetworkQuantizer.cpp
@@ -8,9 +8,9 @@
#include "Graph.hpp"
#include "Layer.hpp"
#include "Network.hpp"
-#include "DynamicQuantizationVisitor.hpp"
-#include "StaticRangeVisitor.hpp"
-#include "QuantizerVisitor.hpp"
+#include "DynamicQuantizationStrategy.hpp"
+#include "StaticRangeStrategy.hpp"
+#include "QuantizerStrategy.hpp"
#include "OverrideInputRangeVisitor.hpp"
#include <TensorIOUtils.hpp>
@@ -60,9 +60,9 @@ void NetworkQuantizer::OverrideInputRange(LayerBindingId layerId, float min, flo
void NetworkQuantizer::Refine(const InputTensors& inputTensors)
{
- // The first time Refine is called the m_Runtime and the DynamicQuantizationVisitor
+ // The first time Refine is called the m_Runtime and the DynamicQuantizationStrategy
// will not have been created. Need to get the environment set up, Runtime loaded,
- // DynamicQuantizationVisitor created and run over the network to initialise itself
+ // DynamicQuantizationStrategy created and run over the network to initialise itself
// and the RangeTracker the Debug callback registered and an initial inference
// done to set up the first min/max values
if (!m_Runtime)
@@ -71,15 +71,15 @@ void NetworkQuantizer::Refine(const InputTensors& inputTensors)
m_Ranges.SetDynamicMode(true);
const Graph& cGraph = PolymorphicDowncast<const Network*>(m_InputNetwork)->GetGraph().TopologicalSort();
- // need to insert Debug layers in the DynamicQuantizationVisitor
+ // need to insert Debug layers in the DynamicQuantizationStrategy
Graph& graph = const_cast<Graph&>(cGraph);
// Initialize RangeTracker to the default values for each layer.
// The default values are overwritten by the min/max that is
// recorded during the first dataset min/max calibration. This
// initialisation is only required for the first call of Refine().
- m_DynamicQuantizationVisitor = DynamicQuantizationVisitor(m_Ranges, graph);
- VisitLayers(cGraph, m_DynamicQuantizationVisitor.value());
+ m_DynamicQuantizationStrategy = DynamicQuantizationStrategy(m_Ranges, graph);
+ ApplyStrategyToLayers(cGraph, m_DynamicQuantizationStrategy.value());
IRuntime::CreationOptions options;
m_Runtime = IRuntime::Create(options);
@@ -119,7 +119,7 @@ void NetworkQuantizer::Refine(const InputTensors& inputTensors)
// Create output tensor for EnqueueWorkload
std::vector<armnn::BindingPointInfo> outputBindings;
- auto outputLayers = m_DynamicQuantizationVisitor.value().GetOutputLayers();
+ auto outputLayers = m_DynamicQuantizationStrategy.value().GetOutputLayers();
std::vector<TContainer> outputVectors;
for (auto outputLayerBindingId : outputLayers)
{
@@ -144,16 +144,16 @@ INetworkPtr NetworkQuantizer::ExportNetwork()
if (!m_Runtime)
{
m_Ranges.SetDynamicMode(false);
- StaticRangeVisitor rangeVisitor(m_Ranges);
- VisitLayers(graph, rangeVisitor);
+ StaticRangeStrategy rangeStrategy(m_Ranges);
+ ApplyStrategyToLayers(graph, rangeStrategy);
}
else
{
// Set min/max range of non-calibrated layers to parent layer's range
- m_DynamicQuantizationVisitor.value().VisitNonCalibratedLayers();
+ m_DynamicQuantizationStrategy.value().VisitNonCalibratedLayers();
// now tear down the runtime and the dynamic visitor.
m_Runtime.reset(nullptr);
- m_DynamicQuantizationVisitor = EmptyOptional();
+ m_DynamicQuantizationStrategy = EmptyOptional();
m_RefineCount = 0;
}
@@ -177,8 +177,8 @@ INetworkPtr NetworkQuantizer::ExportNetwork()
throw InvalidArgumentException("Unsupported quantization target");
}
- QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get(), m_Options.m_PreserveType);
- VisitLayers(graph, quantizerVisitor);
+ QuantizerStrategy quantizerVisitor(m_Ranges, quantizationScheme.get(), m_Options.m_PreserveType);
+ ApplyStrategyToLayers(graph, quantizerVisitor);
// clear the ranges
m_Ranges.Reset();