aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorantkillerfarm <antkillerfarm@gmail.com>2020-10-15 11:02:07 +0800
committermike.kelly <mike.kelly@arm.com>2020-10-23 10:39:21 +0000
commitdb6e8a99b024a697cce1ca5198724c3805440b2a (patch)
tree53b1bef4ebc67ec5749bc9b24c616a92a5ec1c92 /src
parent5f960d92dbefabf708e8f299789b35f6cdf2d919 (diff)
downloadarmnn-db6e8a99b024a697cce1ca5198724c3805440b2a.tar.gz
GitHub#465 Fix NonMaxSuppression
If visited flag set true, it should not be visited any more. For example, if we put 10 boxes (ordered by score) into NonMaxSuppression: * Step1: Suppose Box 2/3/6/8 are suppressed by Box 1. Box 4/5/7/9/10 survived. * Step2: Correct way: We use Box 4 to suppress the survive boxes. Prior to this commit: Box 4 may be suppressed by Box 2, even Box 2 is already suppressed by Box 1... Signed-off-by: Antkillerfarm <antkillerfarm@gmail.com> Change-Id: I38d7a84287649827a16565748592fb562b4df5d5
Diffstat (limited to 'src')
-rw-r--r--src/backends/reference/workloads/DetectionPostProcess.cpp14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/backends/reference/workloads/DetectionPostProcess.cpp b/src/backends/reference/workloads/DetectionPostProcess.cpp
index f80f20a441..2108efe8f3 100644
--- a/src/backends/reference/workloads/DetectionPostProcess.cpp
+++ b/src/backends/reference/workloads/DetectionPostProcess.cpp
@@ -85,14 +85,14 @@ std::vector<unsigned int> NonMaxSuppression(unsigned int numBoxes,
if (!visited[sortedIndices[i]])
{
outputIndices.push_back(indicesAboveThreshold[sortedIndices[i]]);
- }
- for (unsigned int j = i + 1; j < numAboveThreshold; ++j)
- {
- unsigned int iIndex = indicesAboveThreshold[sortedIndices[i]] * 4;
- unsigned int jIndex = indicesAboveThreshold[sortedIndices[j]] * 4;
- if (IntersectionOverUnion(&boxCorners[iIndex], &boxCorners[jIndex]) > nmsIouThreshold)
+ for (unsigned int j = i + 1; j < numAboveThreshold; ++j)
{
- visited[sortedIndices[j]] = true;
+ unsigned int iIndex = indicesAboveThreshold[sortedIndices[i]] * 4;
+ unsigned int jIndex = indicesAboveThreshold[sortedIndices[j]] * 4;
+ if (IntersectionOverUnion(&boxCorners[iIndex], &boxCorners[jIndex]) > nmsIouThreshold)
+ {
+ visited[sortedIndices[j]] = true;
+ }
}
}
}