diff options
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r-- | src/armnn/QuantizerVisitor.cpp | 61 |
1 files changed, 30 insertions, 31 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp index 1212716f97..b5085be0a2 100644 --- a/src/armnn/QuantizerVisitor.cpp +++ b/src/armnn/QuantizerVisitor.cpp @@ -11,17 +11,17 @@ namespace armnn { -QuantizerVisitor::QuantizerVisitor(const StaticRangeVisitor *staticRangeVisitor) +QuantizerVisitor::QuantizerVisitor(const StaticRangeVisitor* staticRangeVisitor) : m_StaticRangeVisitor(staticRangeVisitor) , m_QuantizedNetwork(INetwork::Create()) { BOOST_ASSERT(m_StaticRangeVisitor); } -void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *srcLayer, - IConnectableLayer *quantizedLayer) +void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer* srcLayer, + IConnectableLayer* quantizedLayer) { - for (unsigned int i=0; i < srcLayer->GetNumInputSlots(); i++) + for (unsigned int i = 0; i < srcLayer->GetNumInputSlots(); i++) { const IInputSlot& srcInputSlot = srcLayer->GetInputSlot(i); const InputSlot* inputSlot = boost::polymorphic_downcast<const InputSlot*>(&srcInputSlot); @@ -31,30 +31,29 @@ void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *src Layer& layerToFind = outputSlot->GetOwningLayer(); auto found = m_OriginalToQuantizedGuidMap.find(layerToFind.GetGuid()); - if (found != m_OriginalToQuantizedGuidMap.end()) - { - // Connect the slots in the quantized model - IConnectableLayer* prevQuantizedLayer = m_QuantizedGuidToLayerMap[found->second]; - IInputSlot& newInputSlot = quantizedLayer->GetInputSlot(i); - IOutputSlot& newOutputSlot = prevQuantizedLayer->GetOutputSlot(slotIdx); - newOutputSlot.Connect(newInputSlot); - - // Fetch the min/max ranges that were computed earlier - auto range = m_StaticRangeVisitor->GetRange(layerToFind.GetGuid(), i); - auto qParams = ComputeQAsymmParams(8, range.first, range.second); - - // Set the quantization params - TensorInfo info(newOutputSlot.GetTensorInfo()); - info.SetDataType(DataType::QuantisedAsymm8); - info.SetQuantizationOffset(qParams.first); - info.SetQuantizationScale(qParams.second); - newOutputSlot.SetTensorInfo(info); - } - else + if (found == m_OriginalToQuantizedGuidMap.end()) { // Error in graph traversal order BOOST_ASSERT_MSG(false, "Error in graph traversal"); + return; } + + // Connect the slots in the quantized model + IConnectableLayer* prevQuantizedLayer = m_QuantizedGuidToLayerMap[found->second]; + IInputSlot& newInputSlot = quantizedLayer->GetInputSlot(i); + IOutputSlot& newOutputSlot = prevQuantizedLayer->GetOutputSlot(slotIdx); + newOutputSlot.Connect(newInputSlot); + + // Fetch the min/max ranges that were computed earlier + auto range = m_StaticRangeVisitor->GetRange(layerToFind.GetGuid(), i); + auto qParams = ComputeQAsymmParams(8, range.first, range.second); + + // Set the quantization params + TensorInfo info(newOutputSlot.GetTensorInfo()); + info.SetDataType(DataType::QuantisedAsymm8); + info.SetQuantizationOffset(qParams.first); + info.SetQuantizationScale(qParams.second); + newOutputSlot.SetTensorInfo(info); } } @@ -64,42 +63,42 @@ void QuantizerVisitor::RecordLayer(const IConnectableLayer* srcLayer, IConnectab m_QuantizedGuidToLayerMap[quantizedLayer->GetGuid()] = quantizedLayer; } -void QuantizerVisitor::VisitAdditionLayer(const IConnectableLayer *layer, const char *name) +void QuantizerVisitor::VisitAdditionLayer(const IConnectableLayer* layer, const char* name) { IConnectableLayer* newLayer = m_QuantizedNetwork->AddAdditionLayer(name); RecordLayer(layer, newLayer); SetQuantizedInputConnections(layer, newLayer); } -void QuantizerVisitor::VisitActivationLayer(const IConnectableLayer *layer, +void QuantizerVisitor::VisitActivationLayer(const IConnectableLayer* layer, const ActivationDescriptor& activationDescriptor, - const char *name) + const char* name) { IConnectableLayer* newLayer = m_QuantizedNetwork->AddActivationLayer(activationDescriptor, name); RecordLayer(layer, newLayer); SetQuantizedInputConnections(layer, newLayer); } -void QuantizerVisitor::VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name) +void QuantizerVisitor::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name) { IConnectableLayer* newLayer = m_QuantizedNetwork->AddInputLayer(id, name); RecordLayer(layer, newLayer); } -void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name) +void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name) { IConnectableLayer* newLayer = m_QuantizedNetwork->AddOutputLayer(id, name); RecordLayer(layer, newLayer); SetQuantizedInputConnections(layer, newLayer); } -void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer *layer, +void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer* layer, const BatchNormalizationDescriptor& desc, const ConstTensor& mean, const ConstTensor& variance, const ConstTensor& beta, const ConstTensor& gamma, - const char *name) + const char* name) { std::vector<uint8_t> meanBacking; ConstTensor qMean = CreateQuantizedConst(mean, meanBacking); |