diff options
Diffstat (limited to 'src/armnn/NetworkQuantizationScheme.hpp')
-rw-r--r-- | src/armnn/NetworkQuantizationScheme.hpp | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/src/armnn/NetworkQuantizationScheme.hpp b/src/armnn/NetworkQuantizationScheme.hpp index 0effa1fd64..ea3c29102b 100644 --- a/src/armnn/NetworkQuantizationScheme.hpp +++ b/src/armnn/NetworkQuantizationScheme.hpp @@ -61,6 +61,34 @@ struct QAsymm8QuantizationScheme : IQuantizationScheme DataType GetDataType() const override { return DataType::QuantisedAsymm8; } }; +struct QSymmS8QuantizationScheme : IQuantizationScheme +{ + OffsetScalePair ComputeScheme(double min, double max) const override + { + if (min > max) + { + throw InvalidArgumentException("min > max will result in invalid quantization."); + } + + // To avoid dividing by zero when quantizing a zero filled tensor + if (min == 0.0 && max == 0.0) + { + max = 1.0; + } + + double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit + + double extent = std::max(std::abs(min), std::abs(max)); + double scale = extent / highest; + + return std::make_pair(static_cast<float>(scale), 0); + } + + int NumBits() const override { return 8; } + + DataType GetDataType() const override { return DataType::QSymmS8; } +}; + struct QSymm16QuantizationScheme : IQuantizationScheme { OffsetScalePair ComputeScheme(double min, double max) const override @@ -81,7 +109,12 @@ struct QSymm16QuantizationScheme : IQuantizationScheme double extent = std::max(std::abs(min), std::abs(max)); double scale = extent / highest; + if(scale == 0.000457777642) + { + return std::make_pair(static_cast<float>(scale), 0); + } return std::make_pair(static_cast<float>(scale), 0); + } int NumBits() const override { return 16; } |