diff options
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r--   src/armnn/QuantizerVisitor.cpp | 83 ++++++++++
1 file changed, 83 insertions(+), 0 deletions(-)
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index fd08b2d2e5..afe3713036 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -7,6 +7,8 @@
 #include "QuantizerVisitor.hpp"
 #include "StaticRangeVisitor.hpp"
 
+#include "armnn/TypesUtils.hpp"
+
 #include <cmath>
 #include <stdint.h>
 #include <limits>
@@ -34,6 +36,56 @@ std::pair<int, float> ComputeQAsymmParams(int numBits, double min, double max)
     return std::make_pair(static_cast<int>(std::round(offset)), static_cast<float>(scale));
 }
 
+template<typename srcType>
+void Quantize(const srcType* src, uint8_t* dst, size_t numElements, float &scale, int &offset)
+{
+    BOOST_ASSERT(src);
+    BOOST_ASSERT(dst);
+
+    float min = std::numeric_limits<srcType>::max();
+    float max = std::numeric_limits<srcType>::lowest();
+    for (size_t i = 0; i < numElements; ++i)
+    {
+        min = std::min(min, src[i]);
+        max = std::max(max, src[i]);
+    }
+
+    auto qParams = ComputeQAsymmParams(8, min, max);
+    offset = qParams.first;
+    scale = qParams.second;
+    for (size_t i = 0; i < numElements; ++i)
+    {
+        dst[i] = armnn::Quantize<uint8_t>(src[i], scale, offset);
+    }
+}
+
+ConstTensor CreateQuantizedConst(const ConstTensor& tensor, std::vector<uint8_t> &backing)
+{
+    float scale = 0.0f;
+    int offset = 0;
+    // Reserve the backing memory
+    backing.resize(tensor.GetInfo().GetNumElements());
+
+    DataType type = tensor.GetInfo().GetDataType();
+    switch(type)
+    {
+        case DataType::Float32:
+        {
+            Quantize(static_cast<const float*>( tensor.GetMemoryArea()),
+                     backing.data(),
+                     backing.size(),
+                     scale,
+                     offset);
+        }
+        break;
+        default:
+            BOOST_ASSERT_MSG(false, "Can't quantize unsupported data type");
+    }
+
+    TensorInfo qInfo(tensor.GetInfo().GetShape(), DataType::QuantisedAsymm8, scale, offset);
+    return ConstTensor(qInfo, backing);
+}
+
 } // namespace
 
 QuantizerVisitor::QuantizerVisitor(armnn::StaticRangeVisitor* ranges)
@@ -108,4 +160,35 @@ void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer
*layer, LayerBin
 SetQuantizedInputConnections(layer, newLayer);
 }
 
+void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer *layer,
+                                                    const BatchNormalizationDescriptor& desc,
+                                                    const ConstTensor& mean,
+                                                    const ConstTensor& variance,
+                                                    const ConstTensor& beta,
+                                                    const ConstTensor& gamma,
+                                                    const char *name)
+{
+    std::vector<uint8_t> meanBacking;
+    ConstTensor qMean = CreateQuantizedConst(mean, meanBacking);
+
+    std::vector<uint8_t> varianceBacking;
+    ConstTensor qVariance = CreateQuantizedConst(variance, varianceBacking);
+
+    std::vector<uint8_t> betaBacking;
+    ConstTensor qBeta = CreateQuantizedConst(beta, betaBacking);
+
+    std::vector<uint8_t> gammaBacking;
+    ConstTensor qGamma = CreateQuantizedConst(gamma, gammaBacking);
+
+    IConnectableLayer* newLayer = m_QuantizedNetwork->AddBatchNormalizationLayer(desc,
+                                                                                 qMean,
+                                                                                 qVariance,
+                                                                                 qBeta,
+                                                                                 qGamma,
+                                                                                 name);
+
+    RecordLayer(layer, newLayer);
+    SetQuantizedInputConnections(layer, newLayer);
+}
+
 } //namespace armnn
\ No newline at end of file