diff options
Diffstat (limited to 'src/armnn/QuantizerVisitor.hpp')
-rw-r--r-- | src/armnn/QuantizerVisitor.hpp | 42 |
1 file changed, 25 insertions, 17 deletions
diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp index 0dc45822b4..dcaccd4ac7 100644 --- a/src/armnn/QuantizerVisitor.hpp +++ b/src/armnn/QuantizerVisitor.hpp @@ -6,31 +6,34 @@ #pragma once #include "LayerVisitorBase.hpp" +#include "StaticRangeVisitor.hpp" + #include <armnn/INetwork.hpp> #include <armnn/Types.hpp> +#include <armnn/INetworkQuantizer.hpp> -#include <map> +#include <unordered_map> namespace armnn { -// Forward declarations +// Forward declaration class StaticRangeVisitor; /// Visitor object for quantizing layers in a network class QuantizerVisitor : public LayerVisitorBase<VisitorNoThrowPolicy> { public: - QuantizerVisitor(StaticRangeVisitor* ranges); + QuantizerVisitor(const StaticRangeVisitor* staticRangeVisitor); ~QuantizerVisitor() = default; - // Functions to quantize the individual layers, overridden from ILayerVisitor - void VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name = nullptr) override; - void VisitAdditionLayer(const IConnectableLayer *layer, const char *name = nullptr) override; - void VisitActivationLayer(const IConnectableLayer *layer, + /// Functions to quantize the individual layers, overridden from ILayerVisitor + void VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override; + void VisitAdditionLayer(const IConnectableLayer* layer, const char* name = nullptr) override; + void VisitActivationLayer(const IConnectableLayer* layer, const ActivationDescriptor& activationDescriptor, - const char *name = nullptr) override; - void VisitOutputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name = nullptr) override; + const char* name = nullptr) override; + void VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override; void VisitBatchNormalizationLayer(const IConnectableLayer* layer, const BatchNormalizationDescriptor& desc, const ConstTensor& mean, @@ -39,22 
+42,27 @@ public: const ConstTensor& gamma, const char* name = nullptr) override; - // Extract the quantized network + /// Extract the quantized network INetworkPtr RetrieveFinalNetwork() { return std::move(m_QuantizedNetwork); } -private: +private: /// Connects the layer to preceeding layers and sets the quantization parameters based on recorded ranges - void SetQuantizedInputConnections(const IConnectableLayer *srcLayer, IConnectableLayer *quantizedLayer); + void SetQuantizedInputConnections(const IConnectableLayer* srcLayer, IConnectableLayer* quantizedLayer); /// Record the guids so we can easily find the layers later void RecordLayer(const IConnectableLayer* srcLayer, IConnectableLayer* qLayer); + /// Reference to the static range visitor used to retrieve the quantization ranges + const StaticRangeVisitor* m_StaticRangeVisitor; + + /// Quantized version of the model we are building up + INetworkPtr m_QuantizedNetwork; - StaticRangeVisitor* m_Ranges; ///< Previously recorded min/max ranges per intermediate tensor - INetworkPtr m_QuantizedNetwork; ///< Quantized version of the model we are building up + /// Mapping from input network guids to quantized network guids + std::unordered_map<LayerGuid, LayerGuid> m_OriginalToQuantizedGuidMap; - std::map<LayerGuid, LayerGuid> m_OldToNewGuidMap; ///< Mapping from input network guids to quantized network guids - std::map<LayerGuid, IConnectableLayer*> m_GuidToLayerMap; ///< Mapping from guid to layer in quantized network + /// Mapping from guid to layer in quantized network + std::unordered_map<LayerGuid, IConnectableLayer*> m_QuantizedGuidToLayerMap; }; -} //namespace armnn
\ No newline at end of file +} //namespace armnn |