diff options
-rw-r--r-- | include/armnn/Types.hpp | 16 | ||||
-rw-r--r-- | python/pyarmnn/test/test_types.py | 5 | ||||
-rw-r--r-- | src/armnn/test/UtilsTests.cpp | 22 |
3 files changed, 42 insertions, 1 deletions
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index b5a4266e36..880a6dd816 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -306,7 +306,21 @@ public: PermutationVector(std::initializer_list<ValueType> dimMappings); - ValueType operator[](SizeType i) const { return m_DimMappings.at(i); } + /// + /// Indexing method with out-of-bounds error checking for the m_DimMappings array. + /// @param i - integer value corresponding to index of m_DimMappings array to retrieve element from. + /// @return element at index i of m_DimMappings array. + /// @throws InvalidArgumentException when indexing out-of-bounds index of m_DimMappings array. + /// + ValueType operator[](SizeType i) const + { + if (i >= GetSize()) + { + throw InvalidArgumentException("Invalid indexing of PermutationVector of size " + std::to_string(GetSize()) + + " at location [" + std::to_string(i) + "]."); + } + return m_DimMappings.at(i); + } SizeType GetSize() const { return m_NumDimMappings; } diff --git a/python/pyarmnn/test/test_types.py b/python/pyarmnn/test/test_types.py index dfe1429c02..500a779844 100644 --- a/python/pyarmnn/test/test_types.py +++ b/python/pyarmnn/test/test_types.py @@ -27,3 +27,8 @@ def test_permutation_vector(): pv4 = ann.PermutationVector((0, 3, 1, 2)) assert pv.IsInverse(pv4) + + with pytest.raises(ValueError) as err: + pv4[4] + + assert err.type is ValueError diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp index 1599d0cd35..63884374b3 100644 --- a/src/armnn/test/UtilsTests.cpp +++ b/src/armnn/test/UtilsTests.cpp @@ -269,6 +269,28 @@ TEST_CASE("PermuteQuantizationDim") CHECK(permuted.GetQuantizationDim().value() == 3U); } +TEST_CASE("EmptyPermuteVectorIndexOutOfBounds") +{ + armnn::PermutationVector pv = armnn::PermutationVector({}); + CHECK_THROWS_AS(pv[0], armnn::InvalidArgumentException); +} + +TEST_CASE("PermuteDescriptorIndexOutOfBounds") +{ + armnn::PermutationVector pv = armnn::PermutationVector({ 1u, 2u, 0u }); + armnn::PermuteDescriptor desc = armnn::PermuteDescriptor(pv); + CHECK_THROWS_AS(desc.m_DimMappings[3], armnn::InvalidArgumentException); + CHECK(desc.m_DimMappings[0] == 1u); +} + +TEST_CASE("TransposeDescriptorIndexOutOfBounds") +{ + armnn::PermutationVector pv = armnn::PermutationVector({ 2u, 1u, 0u }); + armnn::TransposeDescriptor desc = armnn::TransposeDescriptor(pv); + CHECK_THROWS_AS(desc.m_DimMappings[3], armnn::InvalidArgumentException); + CHECK(desc.m_DimMappings[2] == 0u); +} + TEST_CASE("PermuteVectorIterator") { // We're slightly breaking the spirit of std::array.end() because we're using it as a |