aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/QuantizerVisitor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r--src/armnn/QuantizerVisitor.cpp41
1 files changed, 35 insertions, 6 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index 919eda1c7e..61e0e60779 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -11,10 +11,13 @@
namespace armnn
{
-QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme)
+QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker,
+ const IQuantizationScheme* quantizationScheme,
+ bool preserveType)
: m_Ranges(rangeTracker)
, m_QuantizedNetwork(INetwork::Create())
, m_QuantizationScheme(quantizationScheme)
+ , m_PreserveType(preserveType)
{
}
@@ -106,15 +109,41 @@ void QuantizerVisitor::VisitFullyConnectedLayer(const IConnectableLayer *layer,
void QuantizerVisitor::VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name)
{
- IConnectableLayer* newLayer = m_QuantizedNetwork->AddInputLayer(id, name);
- RecordLayer(layer, newLayer);
+ const DataType dataType = layer->GetOutputSlot(0).GetTensorInfo().GetDataType();
+ IConnectableLayer* inputLayer = m_QuantizedNetwork->AddInputLayer(id, name);
+
+ if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+ {
+ IConnectableLayer* quantizeLayer = m_QuantizedNetwork->AddQuantizeLayer();
+ inputLayer->GetOutputSlot(0).Connect(quantizeLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(layer->GetOutputSlot(0).GetTensorInfo());
+ RecordLayer(layer, quantizeLayer);
+ }
+ else
+ {
+ RecordLayer(layer, inputLayer);
+ }
}
void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
{
- IConnectableLayer* newLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
- RecordLayer(layer, newLayer);
- SetQuantizedInputConnections(layer, newLayer);
+ const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const DataType& dataType = info.GetDataType();
+ IConnectableLayer* outputLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
+
+ if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+ {
+ IConnectableLayer* dequantizeLayer = m_QuantizedNetwork->AddDequantizeLayer();
+ RecordLayer(layer, dequantizeLayer);
+ SetQuantizedInputConnections(layer, dequantizeLayer);
+ dequantizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ dequantizeLayer->GetOutputSlot(0).SetTensorInfo(info);
+ }
+ else
+ {
+ RecordLayer(layer, outputLayer);
+ SetQuantizedInputConnections(layer, outputLayer);
+ }
}
void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer* layer,