Diffstat (limited to 'include/armnn/TypesUtils.hpp')
-rw-r--r--  include/armnn/TypesUtils.hpp  15
1 file changed, 7 insertions, 8 deletions
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index c65eefc510..e7652649b0 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -189,18 +189,16 @@ inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & s
 template<typename QuantizedType>
 inline QuantizedType Quantize(float value, float scale, int32_t offset)
 {
-    // TODO : check we act sensibly for Inf, NaN and -Inf
-    //        see IVGCVSW-1849
     static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
     constexpr QuantizedType max = std::numeric_limits<QuantizedType>::max();
     constexpr QuantizedType min = std::numeric_limits<QuantizedType>::lowest();
     BOOST_ASSERT(scale != 0.f);
-    int quantized = boost::numeric_cast<int>(round(value / scale)) + offset;
-    QuantizedType quantizedBits = quantized <= min
-                                  ? min
-                                  : quantized >= max
-                                  ? max
-                                  : static_cast<QuantizedType>(quantized);
+    BOOST_ASSERT(!std::isnan(value));
+
+    float clampedValue = std::min(std::max(static_cast<float>(round(value/scale) + offset), static_cast<float>(min)),
+                                  static_cast<float>(max));
+    auto quantizedBits = static_cast<QuantizedType>(clampedValue);
+
     return quantizedBits;
 }
@@ -215,6 +213,7 @@ inline float Dequantize(QuantizedType value, float scale, int32_t offset)
 {
     static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
     BOOST_ASSERT(scale != 0.f);
+    BOOST_ASSERT(!std::isnan(value));
     float dequantized = boost::numeric_cast<float>(value - offset) * scale;
     return dequantized;
 }
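
For context on what the new clamping buys, here is a minimal, self-contained sketch that mirrors the patched Quantize logic outside of Arm NN. QuantizeSketch, the plain assert calls and the sample values are illustrative stand-ins only (the real header uses IsQuantizedType and BOOST_ASSERT as shown in the hunk above); the point is that out-of-range inputs now saturate at the type's limits instead of going through the old integer comparison path.

// Standalone sketch of the clamp-based quantization above; names are illustrative.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <type_traits>

template <typename QuantizedType>
QuantizedType QuantizeSketch(float value, float scale, std::int32_t offset)
{
    static_assert(std::is_integral<QuantizedType>::value, "Not an integer type.");
    constexpr auto min = static_cast<float>(std::numeric_limits<QuantizedType>::lowest());
    constexpr auto max = static_cast<float>(std::numeric_limits<QuantizedType>::max());
    assert(scale != 0.f);        // stands in for BOOST_ASSERT(scale != 0.f)
    assert(!std::isnan(value));  // stands in for BOOST_ASSERT(!std::isnan(value))

    // Round, add the zero-point offset, then clamp to the representable range
    // while still in float, and only then narrow to the quantized type.
    float clampedValue = std::min(std::max(std::round(value / scale) + offset, min), max);
    return static_cast<QuantizedType>(clampedValue);
}

int main()
{
    // 300.0f / 0.5f rounds to 600, which does not fit in int8_t, so the result
    // saturates at 127 instead of wrapping; the mirrored case saturates at -128.
    std::cout << static_cast<int>(QuantizeSketch<std::int8_t>(300.0f, 0.5f, 0))  << '\n'; // 127
    std::cout << static_cast<int>(QuantizeSketch<std::int8_t>(-300.0f, 0.5f, 0)) << '\n'; // -128
    std::cout << static_cast<int>(QuantizeSketch<std::uint8_t>(3.0f, 0.5f, 2))   << '\n'; // 8
    return 0;
}

Built with any C++11 (or later) compiler this should print 127, -128 and 8, confirming the saturating behaviour the patch introduces.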