aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/OverrideInputRangeVisitor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/OverrideInputRangeVisitor.hpp')
-rw-r--r--src/armnn/OverrideInputRangeVisitor.hpp51
1 files changed, 51 insertions, 0 deletions
diff --git a/src/armnn/OverrideInputRangeVisitor.hpp b/src/armnn/OverrideInputRangeVisitor.hpp
index 511c851bef..196a3aab1d 100644
--- a/src/armnn/OverrideInputRangeVisitor.hpp
+++ b/src/armnn/OverrideInputRangeVisitor.hpp
@@ -13,6 +13,57 @@
namespace armnn
{
+class OverrideInputRangeStrategy : public IStrategy
+{
+private:
+ using MinMaxRange = RangeTracker::MinMaxRange;
+public :
+ OverrideInputRangeStrategy(RangeTracker& ranges,
+ LayerBindingId layerId,
+ const MinMaxRange& minMaxRange)
+ : m_Ranges(ranges)
+ , m_LayerId(layerId)
+ , m_MinMaxRange(minMaxRange){}
+
+ ~OverrideInputRangeStrategy() = default;
+
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id) override
+ {
+ IgnoreUnused(name, constants, id, descriptor);
+
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input :
+ {
+ if (m_LayerId == id)
+ {
+ m_Ranges.SetRange(layer, 0, m_MinMaxRange.first, m_MinMaxRange.second);
+ }
+ break;
+ }
+ default:
+ {
+ std::cout << "dont know this one" << std::endl;
+ }
+ }
+ }
+
+private:
+ /// Mapping from a layer Guid to an array of ranges for outputs
+ RangeTracker& m_Ranges;
+
+ /// The id of the input layer of which to override the input range
+ LayerBindingId m_LayerId;
+
+ /// The new input range to be applied to the input layer
+ MinMaxRange m_MinMaxRange;
+};
+
+
/// Visitor object for overriding the input range of the quantized input layers in a network
class OverrideInputRangeVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>