diff options
Diffstat (limited to 'tests/InferenceTest.inl')
-rw-r--r-- | tests/InferenceTest.inl | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/tests/InferenceTest.inl b/tests/InferenceTest.inl index 5e858f06d3..7ce017c6cd 100644 --- a/tests/InferenceTest.inl +++ b/tests/InferenceTest.inl @@ -60,7 +60,18 @@ TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(cons int index = 0; for (const auto & o : output) { - resultMap[ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams)] = index++; + float prob = ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams); + int classification = index++; + + // Take the first class with each probability + // This avoids strange results when looping over batched results produced + // with identical test data. + std::map<float, int>::iterator lb = resultMap.lower_bound(prob); + if (lb == resultMap.end() || + !resultMap.key_comp()(prob, lb->first)) { + // If the key is not already in the map, insert it. + resultMap.insert(lb, std::map<float, int>::value_type(prob, classification)); + } } } |