diff options
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | include/armnn/utility/NumericCast.hpp | 124 | ||||
-rw-r--r-- | src/armnn/test/UtilityTests.cpp | 47 |
3 files changed, 172 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d63e43c04..5002eb4e0b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -249,6 +249,7 @@ list(APPEND armnn_sources include/armnn/profiling/ISendTimelinePacket.hpp include/armnn/utility/Assert.hpp include/armnn/utility/IgnoreUnused.hpp + include/armnn/utility/NumericCast.hpp include/armnn/utility/PolymorphicDowncast.hpp profiling/common/include/SocketConnectionException.hpp src/armnn/layers/LayerCloneBase.hpp diff --git a/include/armnn/utility/NumericCast.hpp b/include/armnn/utility/NumericCast.hpp new file mode 100644 index 0000000000..62c7d11543 --- /dev/null +++ b/include/armnn/utility/NumericCast.hpp @@ -0,0 +1,124 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Assert.hpp" + +#include <type_traits> +#include <limits> + +namespace armnn +{ + +#if !defined(NDEBUG) || defined(ARMNN_NUMERIC_CAST_TESTABLE) +#define ENABLE_NUMERIC_CAST_CHECKS 1 +#else +#define ENABLE_NUMERIC_CAST_CHECKS 0 +#endif + +#if defined(ARMNN_NUMERIC_CAST_TESTABLE) +# define ARMNN_NUMERIC_CAST_CHECK(cond, msg) ConditionalThrow<std::bad_cast>(cond) +#else +# define ARMNN_NUMERIC_CAST_CHECK(cond, msg) ARMNN_ASSERT_MSG(cond, msg) +#endif + +template<typename Dest, typename Source> +typename std::enable_if_t< + std::is_unsigned<Source>::value && + std::is_unsigned<Dest>::value + , Dest> +numeric_cast(Source source) +{ +#if ENABLE_NUMERIC_CAST_CHECKS + if (source > std::numeric_limits<Dest>::max()) + { + ARMNN_NUMERIC_CAST_CHECK(false, "numeric_cast failed casting unsigned type to " + "narrower unsigned type. Overflow detected."); + } +#endif // ENABLE_NUMERIC_CAST_CHECKS + + return static_cast<Dest>(source); +} + +template<typename Dest, typename Source> +typename std::enable_if_t< + std::is_signed<Source>::value && + std::is_signed<Dest>::value + , Dest> +numeric_cast(Source source) +{ + static_assert(!std::is_floating_point<Source>::value && !std::is_floating_point<Dest>::value, + "numeric_cast doesn't cast float."); + +#if ENABLE_NUMERIC_CAST_CHECKS + if (source > std::numeric_limits<Dest>::max()) + { + ARMNN_NUMERIC_CAST_CHECK(false, "numeric_cast failed casting signed type to narrower signed type. " + "Overflow detected."); + } + + if (source < std::numeric_limits<Dest>::lowest()) + { + ARMNN_NUMERIC_CAST_CHECK(false, "numeric_cast failed casting signed type to narrower signed type. " + "Underflow detected."); + } +#endif // ENABLE_NUMERIC_CAST_CHECKS + + return static_cast<Dest>(source); +} + +// numeric cast from unsigned to signed checked for narrowing overflows +template<typename Dest, typename Source> +typename std::enable_if_t< + std::is_signed<Dest>::value && + std::is_unsigned<Source>::value + , Dest> +numeric_cast(Source sValue) +{ + static_assert(!std::is_floating_point<Dest>::value, "numeric_cast doesn't cast to float."); + +#if ENABLE_NUMERIC_CAST_CHECKS + if (sValue > static_cast< typename std::make_unsigned<Dest>::type >(std::numeric_limits<Dest>::max())) + { + ARMNN_NUMERIC_CAST_CHECK(false, "numeric_cast failed casting unsigned type to signed type. " + "Overflow detected."); + } +#endif // ENABLE_NUMERIC_CAST_CHECKS + + return static_cast<Dest>(sValue); +} + +// numeric cast from signed to unsigned checked for underflows and narrowing overflows +template<typename Dest, typename Source> +typename std::enable_if_t< + std::is_unsigned<Dest>::value && + std::is_signed<Source>::value + , Dest> +numeric_cast(Source sValue) +{ + static_assert(!std::is_floating_point<Source>::value && !std::is_floating_point<Dest>::value, + "numeric_cast doesn't cast floats."); + +#if ENABLE_NUMERIC_CAST_CHECKS + if (sValue < 0) + { + ARMNN_NUMERIC_CAST_CHECK(false, "numeric_cast failed casting negative value to unsigned type. " + "Underflow detected."); + } + + if (static_cast< typename std::make_unsigned<Source>::type >(sValue) > std::numeric_limits<Dest>::max()) + { + ARMNN_NUMERIC_CAST_CHECK(false, "numeric_cast failed casting signed type to unsigned type. " + "Overflow detected."); + } + +#endif // ENABLE_NUMERIC_CAST_CHECKS + return static_cast<Dest>(sValue); +} + +#undef ENABLE_NUMERIC_CAST_CHECKS + +} //namespace armnn
\ No newline at end of file diff --git a/src/armnn/test/UtilityTests.cpp b/src/armnn/test/UtilityTests.cpp index 5309d82ce4..7be5c9518a 100644 --- a/src/armnn/test/UtilityTests.cpp +++ b/src/armnn/test/UtilityTests.cpp @@ -8,9 +8,11 @@ #include <boost/polymorphic_cast.hpp> #define ARMNN_POLYMORPHIC_CAST_TESTABLE +#define ARMNN_NUMERIC_CAST_TESTABLE #include <armnn/utility/IgnoreUnused.hpp> #include <armnn/utility/PolymorphicDowncast.hpp> +#include <armnn/utility/NumericCast.hpp> #include <armnn/Exceptions.hpp> @@ -53,4 +55,49 @@ BOOST_AUTO_TEST_CASE(PolymorphicDowncast) armnn::IgnoreUnused(ptr1, ptr2); } + +BOOST_AUTO_TEST_CASE(NumericCast) +{ + using namespace armnn; + + // To 8 bit + BOOST_CHECK_THROW(numeric_cast<unsigned char>(-1), std::bad_cast); + BOOST_CHECK_THROW(numeric_cast<unsigned char>(1 << 8), std::bad_cast); + BOOST_CHECK_THROW(numeric_cast<unsigned char>(1L << 16), std::bad_cast); + BOOST_CHECK_THROW(numeric_cast<unsigned char>(1LL << 32), std::bad_cast); + + BOOST_CHECK_THROW(numeric_cast<signed char>((1L << 8)*-1), std::bad_cast); + BOOST_CHECK_THROW(numeric_cast<signed char>((1L << 15)*-1), std::bad_cast); + BOOST_CHECK_THROW(numeric_cast<signed char>((1LL << 31)*-1), std::bad_cast); + + BOOST_CHECK_NO_THROW(numeric_cast<unsigned char>(1U)); + BOOST_CHECK_NO_THROW(numeric_cast<unsigned char>(1L)); + BOOST_CHECK_NO_THROW(numeric_cast<signed char>(-1)); + BOOST_CHECK_NO_THROW(numeric_cast<signed char>(-1L)); + BOOST_CHECK_NO_THROW(numeric_cast<signed char>((1 << 7)*-1)); + + // To 16 bit + BOOST_CHECK_THROW(numeric_cast<uint16_t>(-1), std::bad_cast); + BOOST_CHECK_THROW(numeric_cast<uint16_t>(1L << 16), std::bad_cast); + BOOST_CHECK_THROW(numeric_cast<uint16_t>(1LL << 32), std::bad_cast); + + BOOST_CHECK_THROW(numeric_cast<int16_t>(1L << 15), std::bad_cast); + BOOST_CHECK_THROW(numeric_cast<int16_t>(1LL << 31), std::bad_cast); + + BOOST_CHECK_NO_THROW(numeric_cast<uint16_t>(1L << 8)); + BOOST_CHECK_NO_THROW(numeric_cast<int16_t>(1L << 7)); + BOOST_CHECK_NO_THROW(numeric_cast<int16_t>((1L << 15)*-1)); + + // To 32 bit + BOOST_CHECK_NO_THROW(numeric_cast<uint32_t>(1)); + BOOST_CHECK_NO_THROW(numeric_cast<uint32_t>(1 << 8)); + BOOST_CHECK_NO_THROW(numeric_cast<uint32_t>(1L << 16)); + BOOST_CHECK_NO_THROW(numeric_cast<uint32_t>(1LL << 31)); + + BOOST_CHECK_NO_THROW(numeric_cast<int32_t>(-1)); + BOOST_CHECK_NO_THROW(numeric_cast<int32_t>((1L << 8)*-1)); + BOOST_CHECK_NO_THROW(numeric_cast<int32_t>((1L << 16)*-1)); + BOOST_CHECK_NO_THROW(numeric_cast<int32_t>((1LL << 31)*-1)); +} + BOOST_AUTO_TEST_SUITE_END() |