author | Derek Lamberti <derek.lamberti@arm.com> | 2019-02-07 11:14:11 +0000
---|---|---
committer | Derek Lamberti <derek.lamberti@arm.com> | 2019-02-07 13:21:28 +0000
commit | 857aa45407df9dbe99a11d03a4be2b20bd0110ae (patch) |
tree | 3e47a2aa1ac8787a00900eff0ba49246ef9a4bdc /src/armnn/QuantizerVisitor.cpp |
parent | 49dbe0e9f6747583cff29ada68d6670796d4216c (diff) |
download | armnn-857aa45407df9dbe99a11d03a4be2b20bd0110ae.tar.gz |
IVGCVSW-2609 Quantize BatchNormalizationLayer
Change-Id: I7b847112a0322ffc8b88a0708d8439bfb97cfe2c
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r-- | src/armnn/QuantizerVisitor.cpp | 83 |
1 file changed, 83 insertions, 0 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index fd08b2d2e5..afe3713036 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -7,6 +7,8 @@
 #include "QuantizerVisitor.hpp"
 #include "StaticRangeVisitor.hpp"
 
+#include "armnn/TypesUtils.hpp"
+
 #include <cmath>
 #include <stdint.h>
 #include <limits>
@@ -34,6 +36,56 @@ std::pair<int, float> ComputeQAsymmParams(int numBits, double min, double max)
     return std::make_pair(static_cast<int>(std::round(offset)), static_cast<float>(scale));
 }
 
+template<typename srcType>
+void Quantize(const srcType* src, uint8_t* dst, size_t numElements, float &scale, int &offset)
+{
+    BOOST_ASSERT(src);
+    BOOST_ASSERT(dst);
+
+    float min = std::numeric_limits<srcType>::max();
+    float max = std::numeric_limits<srcType>::lowest();
+    for (size_t i = 0; i < numElements; ++i)
+    {
+        min = std::min(min, src[i]);
+        max = std::max(max, src[i]);
+    }
+
+    auto qParams = ComputeQAsymmParams(8, min, max);
+    offset = qParams.first;
+    scale = qParams.second;
+    for (size_t i = 0; i < numElements; ++i)
+    {
+        dst[i] = armnn::Quantize<uint8_t>(src[i], scale, offset);
+    }
+}
+
+ConstTensor CreateQuantizedConst(const ConstTensor& tensor, std::vector<uint8_t> &backing)
+{
+    float scale = 0.0f;
+    int offset = 0;
+    // Reserve the backing memory
+    backing.resize(tensor.GetInfo().GetNumElements());
+
+    DataType type = tensor.GetInfo().GetDataType();
+    switch(type)
+    {
+        case DataType::Float32:
+        {
+            Quantize(static_cast<const float*>(tensor.GetMemoryArea()),
+                     backing.data(),
+                     backing.size(),
+                     scale,
+                     offset);
+        }
+        break;
+        default:
+            BOOST_ASSERT_MSG(false, "Can't quantize unsupported data type");
+    }
+
+    TensorInfo qInfo(tensor.GetInfo().GetShape(), DataType::QuantisedAsymm8, scale, offset);
+    return ConstTensor(qInfo, backing);
+}
+
 } // namespace
 
 QuantizerVisitor::QuantizerVisitor(armnn::StaticRangeVisitor* ranges)
@@ -108,4 +160,35 @@ void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer *layer, LayerBin
     SetQuantizedInputConnections(layer, newLayer);
 }
 
+void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer *layer,
+                                                    const BatchNormalizationDescriptor& desc,
+                                                    const ConstTensor& mean,
+                                                    const ConstTensor& variance,
+                                                    const ConstTensor& beta,
+                                                    const ConstTensor& gamma,
+                                                    const char *name)
+{
+    std::vector<uint8_t> meanBacking;
+    ConstTensor qMean = CreateQuantizedConst(mean, meanBacking);
+
+    std::vector<uint8_t> varianceBacking;
+    ConstTensor qVariance = CreateQuantizedConst(variance, varianceBacking);
+
+    std::vector<uint8_t> betaBacking;
+    ConstTensor qBeta = CreateQuantizedConst(beta, betaBacking);
+
+    std::vector<uint8_t> gammaBacking;
+    ConstTensor qGamma = CreateQuantizedConst(gamma, gammaBacking);
+
+    IConnectableLayer* newLayer = m_QuantizedNetwork->AddBatchNormalizationLayer(desc,
+                                                                                 qMean,
+                                                                                 qVariance,
+                                                                                 qBeta,
+                                                                                 qGamma,
+                                                                                 name);
+
+    RecordLayer(layer, newLayer);
+    SetQuantizedInputConnections(layer, newLayer);
+}
+
 } //namespace armnn
\ No newline at end of file