aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/NetworkQuantizationScheme.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/NetworkQuantizationScheme.hpp')
-rw-r--r--src/armnn/NetworkQuantizationScheme.hpp33
1 files changed, 33 insertions, 0 deletions
diff --git a/src/armnn/NetworkQuantizationScheme.hpp b/src/armnn/NetworkQuantizationScheme.hpp
index 0effa1fd64..ea3c29102b 100644
--- a/src/armnn/NetworkQuantizationScheme.hpp
+++ b/src/armnn/NetworkQuantizationScheme.hpp
@@ -61,6 +61,34 @@ struct QAsymm8QuantizationScheme : IQuantizationScheme
DataType GetDataType() const override { return DataType::QuantisedAsymm8; }
};
+struct QSymmS8QuantizationScheme : IQuantizationScheme
+{
+ OffsetScalePair ComputeScheme(double min, double max) const override
+ {
+ if (min > max)
+ {
+ throw InvalidArgumentException("min > max will result in invalid quantization.");
+ }
+
+ // To avoid dividing by zero when quantizing a zero filled tensor
+ if (min == 0.0 && max == 0.0)
+ {
+ max = 1.0;
+ }
+
+ double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit
+
+ double extent = std::max(std::abs(min), std::abs(max));
+ double scale = extent / highest;
+
+ return std::make_pair(static_cast<float>(scale), 0);
+ }
+
+ int NumBits() const override { return 8; }
+
+ DataType GetDataType() const override { return DataType::QSymmS8; }
+};
+
struct QSymm16QuantizationScheme : IQuantizationScheme
{
OffsetScalePair ComputeScheme(double min, double max) const override
@@ -81,7 +109,12 @@ struct QSymm16QuantizationScheme : IQuantizationScheme
double extent = std::max(std::abs(min), std::abs(max));
double scale = extent / highest;
+ if(scale == 0.000457777642)
+ {
+ return std::make_pair(static_cast<float>(scale), 0);
+ }
return std::make_pair(static_cast<float>(scale), 0);
+
}
int NumBits() const override { return 16; }