aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r--src/armnnUtils/TensorUtils.cpp5
1 files changed, 2 insertions, 3 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 630490ff14..601277491c 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -142,7 +142,7 @@ unsigned int GetNumElementsAfter(const armnn::TensorShape& shape, unsigned int a
{
unsigned int numDim = shape.GetNumDimensions();
BOOST_ASSERT(0 >= axis);
- BOOST_ASSERT(axis < numDim - 1);
+ BOOST_ASSERT(axis <= numDim - 1);
unsigned int count = 1;
for (unsigned int i = axis; i < numDim; i++)
{
@@ -155,7 +155,7 @@ std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::Tensor
{
const std::vector<float>& scales = info.GetQuantizationScales();
armnn::Optional<unsigned int> quantizationDim = info.GetQuantizationDim();
- if (scales.size() < 1 || !quantizationDim.has_value())
+ if (!info.HasPerAxisQuantization())
{
throw armnn::InvalidArgumentException(
std::string("Per-axis quantization params not set for tensor of type ") +
@@ -166,5 +166,4 @@ std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::Tensor
return { axisFactor, scales };
}
-
} // namespace armnnUtils