aboutsummaryrefslogtreecommitdiff
path: root/tests/MobileNetSsdInferenceTest.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/MobileNetSsdInferenceTest.hpp')
-rw-r--r--tests/MobileNetSsdInferenceTest.hpp69
1 files changed, 34 insertions, 35 deletions
diff --git a/tests/MobileNetSsdInferenceTest.hpp b/tests/MobileNetSsdInferenceTest.hpp
index 10ee1dcae6..bbbf957dcf 100644
--- a/tests/MobileNetSsdInferenceTest.hpp
+++ b/tests/MobileNetSsdInferenceTest.hpp
@@ -29,7 +29,7 @@ public:
{ std::move(testCaseData.m_InputData) },
{ k_OutputSize1, k_OutputSize2, k_OutputSize3, k_OutputSize4 })
, m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
- , m_DetectedObjects(testCaseData.m_ExpectedOutput)
+ , m_DetectedObjects(testCaseData.m_ExpectedDetectedObject)
{}
TestCaseResult ProcessResult(const InferenceTestOptions& options) override
@@ -46,10 +46,21 @@ public:
const std::vector<float>& output4 = boost::get<std::vector<float>>(this->GetOutputs()[3]); // valid detections
BOOST_ASSERT(output4.size() == k_OutputSize4);
+ const size_t numDetections = boost::numeric_cast<size_t>(output4[0]);
+
+ // Check if number of valid detections matches expectations
+ const size_t expectedNumDetections = m_DetectedObjects.size();
+ if (numDetections != expectedNumDetections)
+ {
+ BOOST_LOG_TRIVIAL(error) << "Number of detections is incorrect: Expected (" <<
+ expectedNumDetections << ")" << " but got (" << numDetections << ")";
+ return TestCaseResult::Failed;
+ }
+
// Extract detected objects from output data
std::vector<DetectedObject> detectedObjects;
const float* outputData = output1.data();
- for (unsigned int i = 0u; i < k_NumDetections; i++)
+ for (unsigned int i = 0u; i < numDetections; i++)
{
// NOTE: Order of coordinates in output data is yMin, xMin, yMax, xMax
float yMin = *outputData++;
@@ -58,61 +69,49 @@ public:
float xMax = *outputData++;
DetectedObject detectedObject(
- static_cast<unsigned int>(output2.at(i)),
+ output2.at(i),
BoundingBox(xMin, yMin, xMax, yMax),
output3.at(i));
detectedObjects.push_back(detectedObject);
}
- // Sort detected objects by confidence
- std::sort(detectedObjects.begin(), detectedObjects.end(),
- [](const DetectedObject& a, const DetectedObject& b)
- {
- return a.m_Confidence > b.m_Confidence ||
- (a.m_Confidence == b.m_Confidence && a.m_Class > b.m_Class);
- });
-
- // Check if number of valid detections matches expectations
- const size_t numValidDetections = boost::numeric_cast<size_t>(output4[0]);
- if (numValidDetections != m_DetectedObjects.size())
- {
- BOOST_LOG_TRIVIAL(error) << "Number of valid detections is incorrect: Expected (" <<
- m_DetectedObjects.size() << ")" << " but got (" << numValidDetections << ")";
- return TestCaseResult::Failed;
- }
+ std::sort(detectedObjects.begin(), detectedObjects.end());
+ std::sort(m_DetectedObjects.begin(), m_DetectedObjects.end());
// Compare detected objects with expected results
std::vector<DetectedObject>::const_iterator it = detectedObjects.begin();
- for (const DetectedObject& expectedDetection : m_DetectedObjects)
+ for (unsigned int i = 0; i < numDetections; i++)
{
if (it == detectedObjects.end())
{
- BOOST_LOG_TRIVIAL(info) << "No more detected objects to compare";
+ BOOST_LOG_TRIVIAL(error) << "No more detected objects found! Index out of bounds: " << i;
return TestCaseResult::Abort;
}
const DetectedObject& detectedObject = *it;
- if (detectedObject.m_Class != expectedDetection.m_Class)
+ const DetectedObject& expectedObject = m_DetectedObjects[i];
+
+ if (detectedObject.m_Class != expectedObject.m_Class)
{
BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << this->GetTestCaseId() <<
- " is incorrect: Expected (" << expectedDetection.m_Class << ")" <<
+ " is incorrect: Expected (" << expectedObject.m_Class << ")" <<
" but predicted (" << detectedObject.m_Class << ")";
return TestCaseResult::Failed;
}
- if(!m_FloatComparer(detectedObject.m_Confidence, expectedDetection.m_Confidence))
+ if(!m_FloatComparer(detectedObject.m_Confidence, expectedObject.m_Confidence))
{
BOOST_LOG_TRIVIAL(error) << "Confidence of prediction for test case " << this->GetTestCaseId() <<
- " is incorrect: Expected (" << expectedDetection.m_Confidence << ") +- 1.0 pc" <<
+ " is incorrect: Expected (" << expectedObject.m_Confidence << ") +- 1.0 pc" <<
" but predicted (" << detectedObject.m_Confidence << ")";
return TestCaseResult::Failed;
}
- if (!m_FloatComparer(detectedObject.m_BoundingBox.m_XMin, expectedDetection.m_BoundingBox.m_XMin) ||
- !m_FloatComparer(detectedObject.m_BoundingBox.m_YMin, expectedDetection.m_BoundingBox.m_YMin) ||
- !m_FloatComparer(detectedObject.m_BoundingBox.m_XMax, expectedDetection.m_BoundingBox.m_XMax) ||
- !m_FloatComparer(detectedObject.m_BoundingBox.m_YMax, expectedDetection.m_BoundingBox.m_YMax))
+ if (!m_FloatComparer(detectedObject.m_BoundingBox.m_XMin, expectedObject.m_BoundingBox.m_XMin) ||
+ !m_FloatComparer(detectedObject.m_BoundingBox.m_YMin, expectedObject.m_BoundingBox.m_YMin) ||
+ !m_FloatComparer(detectedObject.m_BoundingBox.m_XMax, expectedObject.m_BoundingBox.m_XMax) ||
+ !m_FloatComparer(detectedObject.m_BoundingBox.m_YMax, expectedObject.m_BoundingBox.m_YMax))
{
BOOST_LOG_TRIVIAL(error) << "Detected bounding box for test case " << this->GetTestCaseId() <<
" is incorrect";
@@ -126,11 +125,11 @@ public:
}
private:
- static constexpr unsigned int k_NumDetections = 1u;
+ static constexpr unsigned int k_Shape = 10u;
- static constexpr unsigned int k_OutputSize1 = k_NumDetections * 4u;
- static constexpr unsigned int k_OutputSize2 = k_NumDetections;
- static constexpr unsigned int k_OutputSize3 = k_NumDetections;
+ static constexpr unsigned int k_OutputSize1 = k_Shape * 4u;
+ static constexpr unsigned int k_OutputSize2 = k_Shape;
+ static constexpr unsigned int k_OutputSize3 = k_Shape;
static constexpr unsigned int k_OutputSize4 = 1u;
boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
@@ -152,7 +151,7 @@ public:
options.add_options()
("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
- "Path to directory containing test data");
+ "Path to directory containing test data");
Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
}
@@ -169,7 +168,7 @@ public:
{
return false;
}
- std::pair<float, int32_t> qParams = m_Model->GetQuantizationParams();
+ std::pair<float, int32_t> qParams = m_Model->GetInputQuantizationParams();
m_Database = std::make_unique<MobileNetSsdDatabase>(m_DataDir.c_str(), qParams.first, qParams.second);
if (!m_Database)
{