diff options
-rw-r--r-- | include/armnn/Tensor.hpp | 12 | ||||
-rw-r--r-- | src/armnn/Tensor.cpp | 24 |
2 files changed, 28 insertions, 8 deletions
diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp index f4d7f9f984..9380a96af1 100644 --- a/include/armnn/Tensor.hpp +++ b/include/armnn/Tensor.hpp @@ -31,15 +31,9 @@ public: TensorShape& operator=(const TensorShape& other); - unsigned int operator[](unsigned int i) const - { - return m_Dimensions.at(i); - } + unsigned int operator[](unsigned int i) const; - unsigned int& operator[](unsigned int i) - { - return m_Dimensions.at(i); - } + unsigned int& operator[](unsigned int i); bool operator==(const TensorShape& other) const; bool operator!=(const TensorShape& other) const; @@ -50,6 +44,8 @@ public: private: std::array<unsigned int, MaxNumOfTensorDimensions> m_Dimensions; unsigned int m_NumDimensions; + + void CheckDimensionIndex(unsigned int i) const; }; class TensorInfo diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp index 6e09e3bc59..da19e5b97a 100644 --- a/src/armnn/Tensor.cpp +++ b/src/armnn/Tensor.cpp @@ -11,6 +11,8 @@ #include <boost/log/trivial.hpp> #include <boost/numeric/conversion/cast.hpp> +#include <sstream> + namespace armnn { @@ -78,6 +80,18 @@ TensorShape& TensorShape::operator =(const TensorShape& other) return *this; } +unsigned int TensorShape::operator[](unsigned int i) const +{ + CheckDimensionIndex(i); + return m_Dimensions.at(i); +} + +unsigned int& TensorShape::operator[](unsigned int i) +{ + CheckDimensionIndex(i); + return m_Dimensions.at(i); +} + bool TensorShape::operator==(const TensorShape& other) const { return ((m_NumDimensions == other.m_NumDimensions) && @@ -105,6 +119,16 @@ unsigned int TensorShape::GetNumElements() const return count; } +void TensorShape::CheckDimensionIndex(unsigned int i) const +{ + if (i >= m_NumDimensions) + { + std::stringstream errorMessage; + errorMessage << "Invalid dimension index: " << i << " (number of dimensions is " << m_NumDimensions << ")"; + throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION()); + } +} + // --- // --- TensorInfo // --- |