Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r-- src/armnn/QuantizerVisitor.cpp | 83
1 file changed, 83 insertions(+), 0 deletions(-)
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index fd08b2d2e5..afe3713036 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -7,6 +7,8 @@
#include "QuantizerVisitor.hpp"
#include "StaticRangeVisitor.hpp"
+#include "armnn/TypesUtils.hpp"
+
#include <cmath>
#include <stdint.h>
#include <limits>
@@ -34,6 +36,56 @@ std::pair<int, float> ComputeQAsymmParams(int numBits, double min, double max)
return std::make_pair(static_cast<int>(std::round(offset)), static_cast<float>(scale));
}
+template<typename srcType>
+void Quantize(const srcType* src, uint8_t* dst, size_t numElements, float &scale, int &offset)
+{
+ BOOST_ASSERT(src);
+ BOOST_ASSERT(dst);
+
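+    // Scan the source tensor once to find the numeric range the 8-bit encoding must cover.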
+    srcType min = std::numeric_limits<srcType>::max();
+    srcType max = std::numeric_limits<srcType>::lowest();
+ for (size_t i = 0; i < numElements; ++i)
+ {
+ min = std::min(min, src[i]);
+ max = std::max(max, src[i]);
+ }
+
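+    // Derive the asymmetric 8-bit scale/offset from the observed range,
+    // then map every element onto [0, 255].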
+ auto qParams = ComputeQAsymmParams(8, min, max);
+ offset = qParams.first;
+ scale = qParams.second;
+ for (size_t i = 0; i < numElements; ++i)
+ {
+ dst[i] = armnn::Quantize<uint8_t>(src[i], scale, offset);
+ }
+}
+
+ConstTensor CreateQuantizedConst(const ConstTensor& tensor, std::vector<uint8_t> &backing)
+{
+ float scale = 0.0f;
+ int offset = 0;
+    // Size the backing store to hold one quantized byte per tensor element
+    backing.resize(tensor.GetInfo().GetNumElements());
+
+ DataType type = tensor.GetInfo().GetDataType();
+ switch(type)
+ {
+ case DataType::Float32:
+ {
+            Quantize(static_cast<const float*>(tensor.GetMemoryArea()),
+ backing.data(),
+ backing.size(),
+ scale,
+ offset);
+ }
+ break;
+ default:
+ BOOST_ASSERT_MSG(false, "Can't quantize unsupported data type");
+ }
+
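+    // Tag the result as QuantisedAsymm8 so consumers can decode it with the same scale/offset.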
+ TensorInfo qInfo(tensor.GetInfo().GetShape(), DataType::QuantisedAsymm8, scale, offset);
+ return ConstTensor(qInfo, backing);
+}
+
} // namespace
QuantizerVisitor::QuantizerVisitor(armnn::StaticRangeVisitor* ranges)
@@ -108,4 +160,35 @@ void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer *layer, LayerBin
SetQuantizedInputConnections(layer, newLayer);
}
+void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer *layer,
+ const BatchNormalizationDescriptor& desc,
+ const ConstTensor& mean,
+ const ConstTensor& variance,
+ const ConstTensor& beta,
+ const ConstTensor& gamma,
+ const char *name)
+{
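+    // Quantize each constant input into its own local backing store. The ConstTensor
+    // returned by CreateQuantizedConst only views that memory, so the vectors stay in
+    // scope until AddBatchNormalizationLayer has taken its copy.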
+ std::vector<uint8_t> meanBacking;
+ ConstTensor qMean = CreateQuantizedConst(mean, meanBacking);
+
+ std::vector<uint8_t> varianceBacking;
+ ConstTensor qVariance = CreateQuantizedConst(variance, varianceBacking);
+
+ std::vector<uint8_t> betaBacking;
+ ConstTensor qBeta = CreateQuantizedConst(beta, betaBacking);
+
+ std::vector<uint8_t> gammaBacking;
+    ConstTensor qGamma = CreateQuantizedConst(gamma, gammaBacking);
+
+ IConnectableLayer* newLayer = m_QuantizedNetwork->AddBatchNormalizationLayer(desc,
+ qMean,
+ qVariance,
+ qBeta,
+ qGamma,
+ name);
+
+ RecordLayer(layer, newLayer);
+ SetQuantizedInputConnections(layer, newLayer);
+}
+
} //namespace armnn
\ No newline at end of file
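
The body of ComputeQAsymmParams falls outside this hunk. As a reference point, here is a minimal, self-contained sketch of the usual asymmetric 8-bit parameter derivation it appears to implement: the scale spans the observed value range and the zero-point offset is clamped into [0, 255]. This is an assumption drawn from the surrounding code, not ArmNN's exact implementation; ComputeQAsymmParamsSketch and the sample values are hypothetical.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical re-derivation of ComputeQAsymmParams, for illustration only.
std::pair<int, float> ComputeQAsymmParamsSketch(int numBits, double min, double max)
{
    // Ensure the representable range includes 0 so that zero quantizes exactly.
    min = std::min(min, 0.0);
    max = std::max(max, 0.0);

    const double qMin = 0.0;
    const double qMax = std::pow(2.0, numBits) - 1.0; // 255 for 8 bits

    const double scale = (max - min) / (qMax - qMin);
    double offset = std::round(-min / scale);
    offset = std::min(qMax, std::max(qMin, offset)); // clamp the zero point

    return std::make_pair(static_cast<int>(offset), static_cast<float>(scale));
}

int main()
{
    const std::vector<float> values = {-1.0f, -0.5f, 0.0f, 0.75f, 1.5f};

    const auto qParams = ComputeQAsymmParamsSketch(8, -1.0, 1.5);
    const int offset = qParams.first;
    const float scale = qParams.second;

    for (float v : values)
    {
        // The same mapping armnn::Quantize<uint8_t> applies: q = round(v / scale) + offset.
        int q = static_cast<int>(std::round(v / scale)) + offset;
        q = std::min(255, std::max(0, q));
        std::cout << v << " -> " << q << "\n";
    }
    return 0;
}

Note that CreateQuantizedConst returns a non-owning ConstTensor over the caller-supplied backing vector, which is why VisitBatchNormalizationLayer keeps one vector per constant input until AddBatchNormalizationLayer has copied the data into the new layer.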