about summary refs log tree commit diff
path: root/src/armnn/QuantizerVisitor.cpp
diff options
context:
space:
mode:
author    Matteo Martincigh <matteo.martincigh@arm.com>  2019-02-07 17:51:09 +0000
committer Matteo Martincigh <matteo.martincigh@arm.com>  2019-02-08 12:23:05 +0000
commit    a8d572dc48f47e66cd7abd6ad9b2d3a0f40ea94b (patch)
tree      5de7809a8fbc19d6d2a940a51a982bd633156945 /src/armnn/QuantizerVisitor.cpp
parent    e0a4ad8a8e6ef271883e8029985eeab16d838972 (diff)
download  armnn-a8d572dc48f47e66cd7abd6ad9b2d3a0f40ea94b.tar.gz
IVGCVSW-2607 Implement Input range override mechanism
* Added the OverrideInputRange method to the Quantizer API
* Created OverrideInputRangeVisitor to implement the override mechanism
* Moved the quantizer utility functions to the new NetworkQuantizerUtils files
* Moved the map of quantization ranges out of the StaticRangeVisitor and into the NetworkQuantizer
* Added unit tests
* Code refactoring and cleanup

Change-Id: I9c1d006c1b6a35fbc04584a832fbe489f8f9276d
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r--  src/armnn/QuantizerVisitor.cpp  102
1 file changed, 13 insertions, 89 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index 4e075149aa..1212716f97 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -6,92 +6,16 @@
#include "Network.hpp"
#include "QuantizerVisitor.hpp"
#include "StaticRangeVisitor.hpp"
-
-#include "armnn/TypesUtils.hpp"
-
-#include <cmath>
-#include <stdint.h>
-#include <limits>
+#include "NetworkQuantizerUtils.hpp"
namespace armnn
{
-namespace {
-
-std::pair<int, float> ComputeQAsymmParams(int numBits, double min, double max)
-{
- BOOST_ASSERT_MSG(min < max, "Min >= max will result in invalid quantization.");
- double highest = (1 << numBits)-1;
-
- min = std::min(0.0, min); // min <= 0.0
- max = std::max(0.0, max); // max >= 0.0
-
- // assumes quantization range [0-highest]
- double scale = (max-min) / highest;
- double offset = -min / scale;
-
- // clamp offset [0-highest]
- offset = std::max(0.0, std::min(highest, offset));
-
- return std::make_pair(static_cast<int>(std::round(offset)), static_cast<float>(scale));
-}
-
-template<typename srcType>
-void Quantize(const srcType* src, uint8_t* dst, size_t numElements, float &scale, int &offset)
-{
- BOOST_ASSERT(src);
- BOOST_ASSERT(dst);
-
- float min = std::numeric_limits<srcType>::max();
- float max = std::numeric_limits<srcType>::lowest();
- for (size_t i = 0; i < numElements; ++i)
- {
- min = std::min(min, src[i]);
- max = std::max(max, src[i]);
- }
-
- auto qParams = ComputeQAsymmParams(8, min, max);
- offset = qParams.first;
- scale = qParams.second;
- for (size_t i = 0; i < numElements; ++i)
- {
- dst[i] = armnn::Quantize<uint8_t>(src[i], scale, offset);
- }
-}
-
-ConstTensor CreateQuantizedConst(const ConstTensor& tensor, std::vector<uint8_t> &backing)
-{
- float scale = 0.0f;
- int offset = 0;
- // Reserve the backing memory
- backing.resize(tensor.GetInfo().GetNumElements());
-
- DataType type = tensor.GetInfo().GetDataType();
- switch(type)
- {
- case DataType::Float32:
- {
- Quantize(static_cast<const float*>( tensor.GetMemoryArea()),
- backing.data(),
- backing.size(),
- scale,
- offset);
- }
- break;
- default:
- BOOST_ASSERT_MSG(false, "Can't quantize unsupported data type");
- }
-
- TensorInfo qInfo(tensor.GetInfo().GetShape(), DataType::QuantisedAsymm8, scale, offset);
- return ConstTensor(qInfo, backing);
-}
-
-} // namespace
-
-QuantizerVisitor::QuantizerVisitor(armnn::StaticRangeVisitor* ranges)
-: m_Ranges(ranges)
-, m_QuantizedNetwork(INetwork::Create())
+QuantizerVisitor::QuantizerVisitor(const StaticRangeVisitor *staticRangeVisitor)
+ : m_StaticRangeVisitor(staticRangeVisitor)
+ , m_QuantizedNetwork(INetwork::Create())
{
+ BOOST_ASSERT(m_StaticRangeVisitor);
}
void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *srcLayer,
@@ -106,17 +30,17 @@ void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *src
unsigned int slotIdx = outputSlot->CalculateIndexOnOwner();
Layer& layerToFind = outputSlot->GetOwningLayer();
- auto found = m_OldToNewGuidMap.find(layerToFind.GetGuid());
- if (found != m_OldToNewGuidMap.end())
+ auto found = m_OriginalToQuantizedGuidMap.find(layerToFind.GetGuid());
+ if (found != m_OriginalToQuantizedGuidMap.end())
{
// Connect the slots in the quantized model
- IConnectableLayer* prevQuantizedLayer = m_GuidToLayerMap[found->second];
+ IConnectableLayer* prevQuantizedLayer = m_QuantizedGuidToLayerMap[found->second];
IInputSlot& newInputSlot = quantizedLayer->GetInputSlot(i);
IOutputSlot& newOutputSlot = prevQuantizedLayer->GetOutputSlot(slotIdx);
newOutputSlot.Connect(newInputSlot);
// Fetch the min/max ranges that were computed earlier
- auto range = m_Ranges->GetRange(layerToFind.GetGuid(), i);
+ auto range = m_StaticRangeVisitor->GetRange(layerToFind.GetGuid(), i);
auto qParams = ComputeQAsymmParams(8, range.first, range.second);
// Set the quantization params
@@ -128,7 +52,7 @@ void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *src
}
else
{
- // error in graph traversal order
+ // Error in graph traversal order
BOOST_ASSERT_MSG(false, "Error in graph traversal");
}
}
@@ -136,8 +60,8 @@ void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer *src
void QuantizerVisitor::RecordLayer(const IConnectableLayer* srcLayer, IConnectableLayer* quantizedLayer)
{
- m_OldToNewGuidMap[srcLayer->GetGuid()] = quantizedLayer->GetGuid();
- m_GuidToLayerMap[quantizedLayer->GetGuid()] = quantizedLayer;
+ m_OriginalToQuantizedGuidMap[srcLayer->GetGuid()] = quantizedLayer->GetGuid();
+ m_QuantizedGuidToLayerMap[quantizedLayer->GetGuid()] = quantizedLayer;
}
void QuantizerVisitor::VisitAdditionLayer(const IConnectableLayer *layer, const char *name)
@@ -200,4 +124,4 @@ void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer *lay
SetQuantizedInputConnections(layer, newLayer);
}
-} //namespace armnn \ No newline at end of file
+} //namespace armnn