From a4247d5a50502811a6956dffd990c0254622b7e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89anna=20=C3=93=20Cath=C3=A1in?= Date: Wed, 8 May 2019 14:00:45 +0100 Subject: IVGCVSW-2900 Adding the Accuracy Checker Tool and tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I4ac325e45f2236b8e0757d21046f117024ce3979 Signed-off-by: Éanna Ó Catháin --- src/armnn/test/ModelAccuracyCheckerTest.cpp | 98 +++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 src/armnn/test/ModelAccuracyCheckerTest.cpp (limited to 'src/armnn/test') diff --git a/src/armnn/test/ModelAccuracyCheckerTest.cpp b/src/armnn/test/ModelAccuracyCheckerTest.cpp new file mode 100644 index 0000000000..f3a6c9d81d --- /dev/null +++ b/src/armnn/test/ModelAccuracyCheckerTest.cpp @@ -0,0 +1,98 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include "ModelAccuracyChecker.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace armnnUtils; + +struct TestHelper { + const std::map GetValidationLabelSet() + { + std::map validationLabelSet; + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000001", 2)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000002", 9)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000003", 1)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000004", 6)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000005", 5)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000006", 0)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000007", 8)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000008", 4)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000009", 3)); + validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000009", 7)); + + return validationLabelSet; + } +}; + +BOOST_AUTO_TEST_SUITE(ModelAccuracyCheckerTest) + +using TContainer = boost::variant, std::vector, std::vector>; + +BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper) +{ + ModelAccuracyChecker checker(GetValidationLabelSet()); + + // Add image 1 and check accuracy + std::vector inferenceOutputVector1 = {0.05f, 0.10f, 0.70f, 0.15f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + TContainer inference1Container(inferenceOutputVector1); + std::vector outputTensor1; + outputTensor1.push_back(inference1Container); + + std::string imageName = "ILSVRC2012_val_00000001.JPEG"; + checker.AddImageResult(imageName, outputTensor1); + + // Top 1 Accuracy + float totalAccuracy = checker.GetAccuracy(1); + BOOST_CHECK(totalAccuracy == 100.0f); + + // Add image 2 and check accuracy + std::vector inferenceOutputVector2 = {0.10f, 0.0f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f}; + TContainer inference2Container(inferenceOutputVector2); + std::vector outputTensor2; + outputTensor2.push_back(inference2Container); + + imageName = "ILSVRC2012_val_00000002.JPEG"; + checker.AddImageResult(imageName, outputTensor2); + + // Top 1 Accuracy + totalAccuracy = checker.GetAccuracy(1); + BOOST_CHECK(totalAccuracy == 50.0f); + + // Top 2 Accuracy + totalAccuracy = checker.GetAccuracy(2); + BOOST_CHECK(totalAccuracy == 100.0f); + + // Add image 3 and check accuracy + std::vector inferenceOutputVector3 = {0.0f, 0.10f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f}; + TContainer inference3Container(inferenceOutputVector3); + std::vector outputTensor3; + outputTensor3.push_back(inference3Container); + + imageName = "ILSVRC2012_val_00000003.JPEG"; + checker.AddImageResult(imageName, outputTensor3); + + // Top 1 Accuracy + totalAccuracy = checker.GetAccuracy(1); + BOOST_CHECK(totalAccuracy == 33.3333321f); + + // Top 2 Accuracy + totalAccuracy = checker.GetAccuracy(2); + BOOST_CHECK(totalAccuracy == 66.6666641f); + + // Top 3 Accuracy + totalAccuracy = checker.GetAccuracy(3); + BOOST_CHECK(totalAccuracy == 100.0f); +} + +BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1