diff options
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r-- | src/armnn/QuantizerVisitor.cpp | 102 |
1 files changed, 13 insertions, 89 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp index 4e075149aa..1212716f97 100644 --- a/src/armnn/QuantizerVisitor.cpp +++ b/src/armnn/QuantizerVisitor.cpp @@ -6,92 +6,16 @@ #include "Network.hpp" #include "QuantizerVisitor.hpp" #include "StaticRangeVisitor.hpp" - -#include "armnn/TypesUtils.hpp" - -#include <cmath> -#include <stdint.h> -#include <limits> +#include "NetworkQuantizerUtils.hpp" namespace armnn { -namespace { - -std::pair<int, float> ComputeQAsymmParams(int numBits, double min, double max) -{ - BOOST_ASSERT_MSG(min < max, "Min >= max will result in invalid quantization."); - double highest = (1 << numBits)-1; - - min = std::min(0.0, min); // min <= 0.0 - max = std::max(0.0, max); // max >= 0.0 - - // assumes quantization range [0-highest] - double scale = (max-min) / highest; - double offset = -min / scale; - - // clamp offset [0-highest] - offset = std::max(0.0, std::min(highest, offset)); - - return std::make_pair(static_cast<int>(std::round(offset)), static_cast<float>(scale)); -} - -template<typename srcType> -void Quantize(const srcType* src, uint8_t* dst, size_t numElements, float &scale, int &offset) -{ - BOOST_ASSERT(src); - BOOST_ASSERT(dst); - - float min = std::numeric_limits<srcType>::max(); - float max = std::numeric_limits<srcType>::lowest(); - for (size_t i = 0; i < numElements; ++i) - { - min = std::min(min, src[i]); - max = std::max(max, src[i]); - } - - auto qParams = ComputeQAsymmParams(8, min, max); - offset = qParams.first; - scale = qParams.second; - for (size_t i = 0; i < numElements; ++i) - { - dst[i] = armnn::Quantize<uint8_t>(src[i], scale, offset); - } -} - -ConstTensor CreateQuantizedConst(const ConstTensor& tensor, std::vector<uint8_t> &backing) -{ - float scale = 0.0f; - int offset = 0; - // Reserve the backing memory - backing.resize(tensor.GetInfo().GetNumElements()); - - DataType type = tensor.GetInfo().GetDataType(); - switch(type) - { - case DataType::Float32: - { - Quantize(static_cast<const float*>( tensor.GetMemoryArea()), - backing.data(), - backing.size(), - scale, - offset); - } - break; - default: - BOOST_ASSERT_MSG(false, "Can't quantize unsupported data type"); - } - - TensorInfo qInfo(tensor.GetInfo().GetShape(), DataType::QuantisedAsymm8, scale, offset); - return ConstTensor(qInfo, backing); -} - -} // namespace - -QuantizerVisitor::QuantizerVisitor(armnn::StaticRangeVisitor* ranges) -: m_Ranges(ranges) -, m_QuantizedNetwork(INetwork::Create()) +QuantizerVisitor::QuantizerVisitor(const StaticRangeVisitor *staticRangeVisitor) + : m_StaticRangeVisitor(staticRangeVisitor) + , m_QuantizedNetwork(INetwork::Create()) { + BOOST_ASSERT(m_StaticRangeVisitor); } void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *srcLayer, @@ -106,17 +30,17 @@ void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *src unsigned int slotIdx = outputSlot->CalculateIndexOnOwner(); Layer& layerToFind = outputSlot->GetOwningLayer(); - auto found = m_OldToNewGuidMap.find(layerToFind.GetGuid()); - if (found != m_OldToNewGuidMap.end()) + auto found = m_OriginalToQuantizedGuidMap.find(layerToFind.GetGuid()); + if (found != m_OriginalToQuantizedGuidMap.end()) { // Connect the slots in the quantized model - IConnectableLayer* prevQuantizedLayer = m_GuidToLayerMap[found->second]; + IConnectableLayer* prevQuantizedLayer = m_QuantizedGuidToLayerMap[found->second]; IInputSlot& newInputSlot = quantizedLayer->GetInputSlot(i); IOutputSlot& newOutputSlot = prevQuantizedLayer->GetOutputSlot(slotIdx); newOutputSlot.Connect(newInputSlot); // Fetch the min/max ranges that were computed earlier - auto range = m_Ranges->GetRange(layerToFind.GetGuid(), i); + auto range = m_StaticRangeVisitor->GetRange(layerToFind.GetGuid(), i); auto qParams = ComputeQAsymmParams(8, range.first, range.second); // Set the quantization params @@ -128,7 +52,7 @@ void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *src } else { - // error in graph traversal order + // Error in graph traversal order BOOST_ASSERT_MSG(false, "Error in graph traversal"); } } @@ -136,8 +60,8 @@ void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *src void QuantizerVisitor::RecordLayer(const IConnectableLayer* srcLayer, IConnectableLayer* quantizedLayer) { - m_OldToNewGuidMap[srcLayer->GetGuid()] = quantizedLayer->GetGuid(); - m_GuidToLayerMap[quantizedLayer->GetGuid()] = quantizedLayer; + m_OriginalToQuantizedGuidMap[srcLayer->GetGuid()] = quantizedLayer->GetGuid(); + m_QuantizedGuidToLayerMap[quantizedLayer->GetGuid()] = quantizedLayer; } void QuantizerVisitor::VisitAdditionLayer(const IConnectableLayer *layer, const char *name) @@ -200,4 +124,4 @@ void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer *lay SetQuantizedInputConnections(layer, newLayer); } -} //namespace armnn
\ No newline at end of file +} //namespace armnn |