diff options
Diffstat (limited to 'src/armnn/NetworkQuantizer.cpp')
-rw-r--r-- | src/armnn/NetworkQuantizer.cpp | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp index bf5c9ef0f2..f577aea00e 100644 --- a/src/armnn/NetworkQuantizer.cpp +++ b/src/armnn/NetworkQuantizer.cpp @@ -24,14 +24,14 @@ namespace armnn { -INetworkQuantizer* INetworkQuantizer::CreateRaw(INetwork* inputNetwork) +INetworkQuantizer* INetworkQuantizer::CreateRaw(INetwork* inputNetwork, const QuantizerOptions& options) { - return new NetworkQuantizer(inputNetwork); + return new NetworkQuantizer(inputNetwork, options); } -INetworkQuantizerPtr INetworkQuantizer::Create(INetwork* inputNetwork) +INetworkQuantizerPtr INetworkQuantizer::Create(INetwork* inputNetwork, const QuantizerOptions& options) { - return INetworkQuantizerPtr(CreateRaw(inputNetwork), &INetworkQuantizer::Destroy); + return INetworkQuantizerPtr(CreateRaw(inputNetwork, options), &INetworkQuantizer::Destroy); } void INetworkQuantizer::Destroy(INetworkQuantizer *quantizer) @@ -58,7 +58,20 @@ INetworkPtr NetworkQuantizer::ExportNetwork() VisitLayers(graph, rangeVisitor); // Step 2) Convert input InputNetwork to Quantized InputNetwork - QuantizerVisitor quantizerVisitor(m_Ranges); + std::unique_ptr<IQuantizationScheme> quantizationScheme; + switch (m_Options.m_ActivationFormat) + { + case DataType::QuantisedAsymm8: + quantizationScheme = std::make_unique<QAsymm8QuantizationScheme>(); + break; + case DataType::QuantisedSymm16: + quantizationScheme = std::make_unique<QSymm16QuantizationScheme>(); + break; + default: + throw InvalidArgumentException("Unsupported quantization target"); + } + + QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get()); VisitLayers(graph, quantizerVisitor); return quantizerVisitor.RetrieveFinalNetwork(); |