diff options
author | Nattapat Chaimanowong <nattapat.chaimanowong@arm.com> | 2019-05-09 10:13:20 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-10 08:30:55 +0000 |
commit | 724e48013142562b7f09c9c819f57c314c4ee3d4 (patch) | |
tree | 935b425961ccb13a488708312bcc70b3b32fc87e /src/armnn/QuantizerVisitor.cpp | |
parent | 5fa83938592db420914903235daf3f1d5c97d6bc (diff) | |
download | armnn-724e48013142562b7f09c9c819f57c314c4ee3d4.tar.gz |
IVGCVSW-3061 Modify NetworkQuantizer to support option to preserve input/output types
* Also add unit tests for new preserve type option
Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Change-Id: I860759072f2e3546698118d1bcd5e79eb4e805ec
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r-- | src/armnn/QuantizerVisitor.cpp | 41 |
1 file changed, 35 insertions(+), 6 deletions(-)
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp index 919eda1c7e..61e0e60779 100644 --- a/src/armnn/QuantizerVisitor.cpp +++ b/src/armnn/QuantizerVisitor.cpp @@ -11,10 +11,13 @@ namespace armnn { -QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme) +QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker, + const IQuantizationScheme* quantizationScheme, + bool preserveType) : m_Ranges(rangeTracker) , m_QuantizedNetwork(INetwork::Create()) , m_QuantizationScheme(quantizationScheme) + , m_PreserveType(preserveType) { } @@ -106,15 +109,41 @@ void QuantizerVisitor::VisitFullyConnectedLayer(const IConnectableLayer *layer, void QuantizerVisitor::VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name) { - IConnectableLayer* newLayer = m_QuantizedNetwork->AddInputLayer(id, name); - RecordLayer(layer, newLayer); + const DataType dataType = layer->GetOutputSlot(0).GetTensorInfo().GetDataType(); + IConnectableLayer* inputLayer = m_QuantizedNetwork->AddInputLayer(id, name); + + if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16)) + { + IConnectableLayer* quantizeLayer = m_QuantizedNetwork->AddQuantizeLayer(); + inputLayer->GetOutputSlot(0).Connect(quantizeLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(layer->GetOutputSlot(0).GetTensorInfo()); + RecordLayer(layer, quantizeLayer); + } + else + { + RecordLayer(layer, inputLayer); + } } void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name) { - IConnectableLayer* newLayer = m_QuantizedNetwork->AddOutputLayer(id, name); - RecordLayer(layer, newLayer); - SetQuantizedInputConnections(layer, newLayer); + const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo(); + const DataType& dataType = info.GetDataType(); + IConnectableLayer* outputLayer = m_QuantizedNetwork->AddOutputLayer(id, name); + + if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16)) + { + IConnectableLayer* dequantizeLayer = m_QuantizedNetwork->AddDequantizeLayer(); + RecordLayer(layer, dequantizeLayer); + SetQuantizedInputConnections(layer, dequantizeLayer); + dequantizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + dequantizeLayer->GetOutputSlot(0).SetTensorInfo(info); + } + else + { + RecordLayer(layer, outputLayer); + SetQuantizedInputConnections(layer, outputLayer); + } } void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer* layer,