aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r--src/armnnUtils/TensorUtils.cpp4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 2890399cd8..505c9f8588 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -142,7 +142,7 @@ unsigned int GetNumElementsAfter(const armnn::TensorShape& shape, unsigned int a
unsigned int numDim = shape.GetNumDimensions();
ARMNN_ASSERT(axis <= numDim - 1);
unsigned int count = 1;
- for (unsigned int i = axis; i < numDim; i++)
+ for (unsigned int i = axis+1; i < numDim; i++)
{
count *= shape[i];
}
@@ -159,7 +159,7 @@ std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::Tensor
std::string("Per-axis quantization params not set for tensor of type ") +
armnn::GetDataTypeName(info.GetDataType()), CHECK_LOCATION());
}
- unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value());
+ unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value()) ;
return { axisFactor, scales };
}