aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/StaticRangeVisitor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/StaticRangeVisitor.cpp')
-rw-r--r--src/armnn/StaticRangeVisitor.cpp31
1 files changed, 31 insertions, 0 deletions
diff --git a/src/armnn/StaticRangeVisitor.cpp b/src/armnn/StaticRangeVisitor.cpp
index b1cbb2d574..cc1255e56e 100644
--- a/src/armnn/StaticRangeVisitor.cpp
+++ b/src/armnn/StaticRangeVisitor.cpp
@@ -9,6 +9,8 @@
#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>
+#include <limits>
+
namespace armnn
{
@@ -150,4 +152,33 @@ void StaticRangeVisitor::VisitSoftmaxLayer(const IConnectableLayer* layer,
SetRange(layer, 0, 0.f, 1.f);
}
+void StaticRangeVisitor::VisitConstantLayer(const IConnectableLayer* layer,
+ const ConstTensor& input,
+ const char* name)
+{
+ boost::ignore_unused(name);
+
+ if (input.GetDataType() != DataType::Float32)
+ {
+ throw InvalidArgumentException("Quantization is supported only for FP32 tensors");
+ }
+
+ // Work out the range based on the input constants
+ unsigned int inputNumElements = input.GetNumElements();
+ const float* inputData = reinterpret_cast<const float*>(input.GetMemoryArea());
+
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+
+ for (unsigned int i = 0; i < inputNumElements; i++)
+ {
+ const float inputValue = inputData[i];
+
+ min = std::min(min, inputValue);
+ max = std::max(max, inputValue);
+ }
+
+ SetRange(layer, 0, min, max);
+}
+
} //namespace armnn