diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnn/Tensor.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp index 6e09e3bc59..da19e5b97a 100644 --- a/src/armnn/Tensor.cpp +++ b/src/armnn/Tensor.cpp @@ -11,6 +11,8 @@ #include <boost/log/trivial.hpp> #include <boost/numeric/conversion/cast.hpp> +#include <sstream> + namespace armnn { @@ -78,6 +80,18 @@ TensorShape& TensorShape::operator =(const TensorShape& other) return *this; } +unsigned int TensorShape::operator[](unsigned int i) const +{ + CheckDimensionIndex(i); + return m_Dimensions.at(i); +} + +unsigned int& TensorShape::operator[](unsigned int i) +{ + CheckDimensionIndex(i); + return m_Dimensions.at(i); +} + bool TensorShape::operator==(const TensorShape& other) const { return ((m_NumDimensions == other.m_NumDimensions) && @@ -105,6 +119,16 @@ unsigned int TensorShape::GetNumElements() const return count; } +void TensorShape::CheckDimensionIndex(unsigned int i) const +{ + if (i >= m_NumDimensions) + { + std::stringstream errorMessage; + errorMessage << "Invalid dimension index: " << i << " (number of dimensions is " << m_NumDimensions << ")"; + throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION()); + } +} + // --- // --- TensorInfo // --- |