From 06e25c41e8727cc859c2b6d1988a988e90bb537b Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Thu, 21 Feb 2019 15:45:03 +0000 Subject: IVGCVSW-2749 Throw exception in TensorShape when requested index >= number of dimensions Change-Id: I3589b1e901b0f81f6bb17848046a22829f91bb9e Signed-off-by: Aron Virginas-Tar --- include/armnn/Tensor.hpp | 12 ++++-------- src/armnn/Tensor.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp index f4d7f9f984..9380a96af1 100644 --- a/include/armnn/Tensor.hpp +++ b/include/armnn/Tensor.hpp @@ -31,15 +31,9 @@ public: TensorShape& operator=(const TensorShape& other); - unsigned int operator[](unsigned int i) const - { - return m_Dimensions.at(i); - } + unsigned int operator[](unsigned int i) const; - unsigned int& operator[](unsigned int i) - { - return m_Dimensions.at(i); - } + unsigned int& operator[](unsigned int i); bool operator==(const TensorShape& other) const; bool operator!=(const TensorShape& other) const; @@ -50,6 +44,8 @@ public: private: std::array m_Dimensions; unsigned int m_NumDimensions; + + void CheckDimensionIndex(unsigned int i) const; }; class TensorInfo diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp index 6e09e3bc59..da19e5b97a 100644 --- a/src/armnn/Tensor.cpp +++ b/src/armnn/Tensor.cpp @@ -11,6 +11,8 @@ #include #include +#include + namespace armnn { @@ -78,6 +80,18 @@ TensorShape& TensorShape::operator =(const TensorShape& other) return *this; } +unsigned int TensorShape::operator[](unsigned int i) const +{ + CheckDimensionIndex(i); + return m_Dimensions.at(i); +} + +unsigned int& TensorShape::operator[](unsigned int i) +{ + CheckDimensionIndex(i); + return m_Dimensions.at(i); +} + bool TensorShape::operator==(const TensorShape& other) const { return ((m_NumDimensions == other.m_NumDimensions) && @@ -105,6 +119,16 @@ unsigned int TensorShape::GetNumElements() const return count; } +void TensorShape::CheckDimensionIndex(unsigned int i) const +{ + if (i >= m_NumDimensions) + { + std::stringstream errorMessage; + errorMessage << "Invalid dimension index: " << i << " (number of dimensions is " << m_NumDimensions << ")"; + throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION()); + } +} + // --- // --- TensorInfo // --- -- cgit v1.2.1