aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Tensor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Tensor.cpp')
-rw-r--r--src/armnn/Tensor.cpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp
index 6e09e3bc59..da19e5b97a 100644
--- a/src/armnn/Tensor.cpp
+++ b/src/armnn/Tensor.cpp
@@ -11,6 +11,8 @@
#include <boost/log/trivial.hpp>
#include <boost/numeric/conversion/cast.hpp>
+#include <sstream>
+
namespace armnn
{
@@ -78,6 +80,18 @@ TensorShape& TensorShape::operator =(const TensorShape& other)
return *this;
}
+unsigned int TensorShape::operator[](unsigned int i) const
+{
+ CheckDimensionIndex(i);
+ return m_Dimensions.at(i);
+}
+
+unsigned int& TensorShape::operator[](unsigned int i)
+{
+ CheckDimensionIndex(i);
+ return m_Dimensions.at(i);
+}
+
bool TensorShape::operator==(const TensorShape& other) const
{
return ((m_NumDimensions == other.m_NumDimensions) &&
@@ -105,6 +119,16 @@ unsigned int TensorShape::GetNumElements() const
return count;
}
+void TensorShape::CheckDimensionIndex(unsigned int i) const
+{
+ if (i >= m_NumDimensions)
+ {
+ std::stringstream errorMessage;
+ errorMessage << "Invalid dimension index: " << i << " (number of dimensions is " << m_NumDimensions << ")";
+ throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION());
+ }
+}
+
// ---
// --- TensorInfo
// ---