aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/QuantizerVisitor.cpp
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-05-09 10:13:20 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-05-10 08:30:55 +0000
commit724e48013142562b7f09c9c819f57c314c4ee3d4 (patch)
tree935b425961ccb13a488708312bcc70b3b32fc87e /src/armnn/QuantizerVisitor.cpp
parent5fa83938592db420914903235daf3f1d5c97d6bc (diff)
downloadarmnn-724e48013142562b7f09c9c819f57c314c4ee3d4.tar.gz
IVGCVSW-3061 Modify NetworkQuantizer to support option to preserve input/output types
* Also add unit tests for new preserve type option. Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com> Change-Id: I860759072f2e3546698118d1bcd5e79eb4e805ec
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r--src/armnn/QuantizerVisitor.cpp41
1 file changed, 35 insertions, 6 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index 919eda1c7e..61e0e60779 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -11,10 +11,13 @@
namespace armnn
{
-QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme)
+QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker,
+ const IQuantizationScheme* quantizationScheme,
+ bool preserveType)
: m_Ranges(rangeTracker)
, m_QuantizedNetwork(INetwork::Create())
, m_QuantizationScheme(quantizationScheme)
+ , m_PreserveType(preserveType)
{
}
@@ -106,15 +109,41 @@ void QuantizerVisitor::VisitFullyConnectedLayer(const IConnectableLayer *layer,
void QuantizerVisitor::VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name)
{
- IConnectableLayer* newLayer = m_QuantizedNetwork->AddInputLayer(id, name);
- RecordLayer(layer, newLayer);
+ const DataType dataType = layer->GetOutputSlot(0).GetTensorInfo().GetDataType();
+ IConnectableLayer* inputLayer = m_QuantizedNetwork->AddInputLayer(id, name);
+
+ if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+ {
+ IConnectableLayer* quantizeLayer = m_QuantizedNetwork->AddQuantizeLayer();
+ inputLayer->GetOutputSlot(0).Connect(quantizeLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(layer->GetOutputSlot(0).GetTensorInfo());
+ RecordLayer(layer, quantizeLayer);
+ }
+ else
+ {
+ RecordLayer(layer, inputLayer);
+ }
}
void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
{
- IConnectableLayer* newLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
- RecordLayer(layer, newLayer);
- SetQuantizedInputConnections(layer, newLayer);
+ const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const DataType& dataType = info.GetDataType();
+ IConnectableLayer* outputLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
+
+ if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+ {
+ IConnectableLayer* dequantizeLayer = m_QuantizedNetwork->AddDequantizeLayer();
+ RecordLayer(layer, dequantizeLayer);
+ SetQuantizedInputConnections(layer, dequantizeLayer);
+ dequantizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ dequantizeLayer->GetOutputSlot(0).SetTensorInfo(info);
+ }
+ else
+ {
+ RecordLayer(layer, outputLayer);
+ SetQuantizedInputConnections(layer, outputLayer);
+ }
}
void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer* layer,