// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "LayerVisitorBase.hpp" #include "StaticRangeVisitor.hpp" #include #include #include #include namespace armnn { // Forward declaration class StaticRangeVisitor; /// Visitor object for quantizing layers in a network class QuantizerVisitor : public LayerVisitorBase { public: QuantizerVisitor(const StaticRangeVisitor* staticRangeVisitor); ~QuantizerVisitor() = default; /// Functions to quantize the individual layers, overridden from ILayerVisitor void VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override; void VisitAdditionLayer(const IConnectableLayer* layer, const char* name = nullptr) override; void VisitActivationLayer(const IConnectableLayer* layer, const ActivationDescriptor& activationDescriptor, const char* name = nullptr) override; void VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override; void VisitBatchNormalizationLayer(const IConnectableLayer* layer, const BatchNormalizationDescriptor& desc, const ConstTensor& mean, const ConstTensor& variance, const ConstTensor& beta, const ConstTensor& gamma, const char* name = nullptr) override; void VisitFullyConnectedLayer(const IConnectableLayer *layer, const FullyConnectedDescriptor&, const ConstTensor&, const Optional&, const char *name = nullptr) override; void VisitConvolution2dLayer(const IConnectableLayer* layer, const Convolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, const Optional& biases, const char* name = nullptr) override; void VisitSoftmaxLayer(const IConnectableLayer* layer, const SoftmaxDescriptor& softmaxDescriptor, const char* name = nullptr) override; /// Extract the quantized network INetworkPtr RetrieveFinalNetwork() { return std::move(m_QuantizedNetwork); } private: /// Connects the layer to preceeding layers and sets the quantization parameters based on recorded ranges void SetQuantizedInputConnections(const IConnectableLayer* srcLayer, IConnectableLayer* quantizedLayer); /// Record the guids so we can easily find the layers later void RecordLayer(const IConnectableLayer* srcLayer, IConnectableLayer* qLayer); /// Reference to the static range visitor used to retrieve the quantization ranges const StaticRangeVisitor* m_StaticRangeVisitor; /// Quantized version of the model we are building up INetworkPtr m_QuantizedNetwork; /// Mapping from input network guids to quantized network guids std::unordered_map m_OriginalToQuantizedGuidMap; /// Mapping from guid to layer in quantized network std::unordered_map m_QuantizedGuidToLayerMap; }; } //namespace armnn