// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "NetworkQuantizer.hpp" #include "armnn/LayerVisitorBase.hpp" #include "RangeTracker.hpp" #include namespace armnn { class OverrideInputRangeStrategy : public IStrategy { private: using MinMaxRange = RangeTracker::MinMaxRange; public : OverrideInputRangeStrategy(RangeTracker& ranges, LayerBindingId layerId, const MinMaxRange& minMaxRange) : m_Ranges(ranges) , m_LayerId(layerId) , m_MinMaxRange(minMaxRange){} ~OverrideInputRangeStrategy() = default; void ExecuteStrategy(const armnn::IConnectableLayer* layer, const BaseDescriptor& descriptor, const std::vector& constants, const char* name, const armnn::LayerBindingId id) override { IgnoreUnused(name, constants, id, descriptor); switch (layer->GetType()) { case armnn::LayerType::Input : { if (m_LayerId == id) { m_Ranges.SetRange(layer, 0, m_MinMaxRange.first, m_MinMaxRange.second); } break; } default: { std::cout << "dont know this one" << std::endl; } } } private: /// Mapping from a layer Guid to an array of ranges for outputs RangeTracker& m_Ranges; /// The id of the input layer of which to override the input range LayerBindingId m_LayerId; /// The new input range to be applied to the input layer MinMaxRange m_MinMaxRange; }; /// Visitor object for overriding the input range of the quantized input layers in a network class OverrideInputRangeVisitor : public LayerVisitorBase { private: using MinMaxRange = RangeTracker::MinMaxRange; public: OverrideInputRangeVisitor(RangeTracker& ranges, LayerBindingId layerId, const MinMaxRange& minMaxRange); ~OverrideInputRangeVisitor() = default; void VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override; private: /// Mapping from a layer Guid to an array of ranges for outputs RangeTracker& m_Ranges; /// The id of the input layer of which to override the input range LayerBindingId m_LayerId; /// The new input range to be applied to the input layer MinMaxRange m_MinMaxRange; }; } // namespace armnn