aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/NetworkQuantizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/NetworkQuantizer.cpp')
-rw-r--r--src/armnn/NetworkQuantizer.cpp23
1 files changed, 18 insertions, 5 deletions
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp
index bf5c9ef0f2..f577aea00e 100644
--- a/src/armnn/NetworkQuantizer.cpp
+++ b/src/armnn/NetworkQuantizer.cpp
@@ -24,14 +24,14 @@
namespace armnn
{
-INetworkQuantizer* INetworkQuantizer::CreateRaw(INetwork* inputNetwork)
+INetworkQuantizer* INetworkQuantizer::CreateRaw(INetwork* inputNetwork, const QuantizerOptions& options)
{
- return new NetworkQuantizer(inputNetwork);
+ return new NetworkQuantizer(inputNetwork, options);
}
-INetworkQuantizerPtr INetworkQuantizer::Create(INetwork* inputNetwork)
+INetworkQuantizerPtr INetworkQuantizer::Create(INetwork* inputNetwork, const QuantizerOptions& options)
{
- return INetworkQuantizerPtr(CreateRaw(inputNetwork), &INetworkQuantizer::Destroy);
+ return INetworkQuantizerPtr(CreateRaw(inputNetwork, options), &INetworkQuantizer::Destroy);
}
void INetworkQuantizer::Destroy(INetworkQuantizer *quantizer)
@@ -58,7 +58,20 @@ INetworkPtr NetworkQuantizer::ExportNetwork()
VisitLayers(graph, rangeVisitor);
// Step 2) Convert input InputNetwork to Quantized InputNetwork
- QuantizerVisitor quantizerVisitor(m_Ranges);
+ std::unique_ptr<IQuantizationScheme> quantizationScheme;
+ switch (m_Options.m_ActivationFormat)
+ {
+ case DataType::QuantisedAsymm8:
+ quantizationScheme = std::make_unique<QAsymm8QuantizationScheme>();
+ break;
+ case DataType::QuantisedSymm16:
+ quantizationScheme = std::make_unique<QSymm16QuantizationScheme>();
+ break;
+ default:
+ throw InvalidArgumentException("Unsupported quantization target");
+ }
+
+ QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get());
VisitLayers(graph, quantizerVisitor);
return quantizerVisitor.RetrieveFinalNetwork();