aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2018-10-29 17:39:49 +0000
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-29 17:41:53 +0000
commit4322d36a2d6d9fca16a661019b8c5dac0c1e81ec (patch)
tree72fdebd79624f268707cc640d1f153b5d2341bd1
parent382a91d5029e83002bda4ab006f9c73340d679fe (diff)
downloadarmnn-4322d36a2d6d9fca16a661019b8c5dac0c1e81ec.tar.gz
IVGCVSW-2029 Tweak results handling for batch size 2 test
When looking for the top probability, use the 'first' result not the 'second'. This avoids an issue where for batched tests the classification index was reported wrongly. Still doesn't correctly handle multiple results with the exact probabibility, or batched testing, but it's slightly more correct than before. Change-Id: I57d33552754667613e222d9d2037e12c87a96854
-rw-r--r--tests/InferenceTest.inl13
-rw-r--r--tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp1
2 files changed, 12 insertions, 2 deletions
diff --git a/tests/InferenceTest.inl b/tests/InferenceTest.inl
index 5e858f06d3..7ce017c6cd 100644
--- a/tests/InferenceTest.inl
+++ b/tests/InferenceTest.inl
@@ -60,7 +60,18 @@ TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(cons
int index = 0;
for (const auto & o : output)
{
- resultMap[ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams)] = index++;
+ float prob = ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams);
+ int classification = index++;
+
+ // Take the first class with each probability
+ // This avoids strange results when looping over batched results produced
+ // with identical test data.
+ std::map<float, int>::iterator lb = resultMap.lower_bound(prob);
+ if (lb == resultMap.end() ||
+ !resultMap.key_comp()(prob, lb->first)) {
+ // If the key is not already in the map, insert it.
+ resultMap.insert(lb, std::map<float, int>::value_type(prob, classification));
+ }
}
}
diff --git a/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp b/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
index 1313d2d01a..acaabe4487 100644
--- a/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
+++ b/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
@@ -16,7 +16,6 @@ int main(int argc, char* argv[])
// Coverity fix: The following code may throw an exception of type std::length_error.
std::vector<ImageSet> imageSet =
{
- // Class number in probability print out offset by 1000 due to batch size fix
{"Dog.jpg", 669},
{"Cat.jpg", 669},
{"shark.jpg", 669},