aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceTest.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
committertelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
commit4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch)
treec9a70aeb2887006160c1b3d265c27efadb7bdbae /tests/InferenceTest.cpp
downloadarmnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'tests/InferenceTest.cpp')
-rw-r--r--tests/InferenceTest.cpp236
1 files changed, 236 insertions, 0 deletions
diff --git a/tests/InferenceTest.cpp b/tests/InferenceTest.cpp
new file mode 100644
index 0000000000..55616798e2
--- /dev/null
+++ b/tests/InferenceTest.cpp
@@ -0,0 +1,236 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#include "InferenceTest.hpp"
+
+#include <boost/algorithm/string.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+#include <boost/log/trivial.hpp>
+#include <boost/filesystem/path.hpp>
+#include <boost/assert.hpp>
+#include <boost/format.hpp>
+#include <boost/program_options.hpp>
+#include <boost/filesystem/operations.hpp>
+
+#include <fstream>
+#include <iostream>
+#include <iomanip>
+#include <array>
+
+using namespace std;
+using namespace std::chrono;
+using namespace armnn::test;
+
+namespace armnn
+{
+namespace test
+{
+
+/// Parse the command line of an ArmNN (or referencetests) inference test program.
+/// \return false if any error occurred during options processing, otherwise true
+bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
+ InferenceTestOptions& outParams)
+{
+ namespace po = boost::program_options;
+
+ std::string computeDeviceStr;
+
+ po::options_description desc("Options");
+
+ try
+ {
+ // Add generic options needed for all inference tests
+ desc.add_options()
+ ("help", "Display help messages")
+ ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
+ "Sets the number number of inferences to perform. If unset, a default number will be ran.")
+ ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
+ "If non-empty, each individual inference time will be recorded and output to this file");
+
+ // Add options specific to the ITestCaseProvider
+ testCaseProvider.AddCommandLineOptions(desc);
+ }
+ catch (const std::exception& e)
+ {
+ // Coverity points out that default_value(...) can throw a bad_lexical_cast,
+ // and that desc.add_options() can throw boost::io::too_few_args.
+ // They really won't in any of these cases.
+ BOOST_ASSERT_MSG(false, "Caught unexpected exception");
+ std::cerr << "Fatal internal error: " << e.what() << std::endl;
+ return false;
+ }
+
+ po::variables_map vm;
+
+ try
+ {
+ po::store(po::parse_command_line(argc, argv, desc), vm);
+
+ if (vm.count("help"))
+ {
+ std::cout << desc << std::endl;
+ return false;
+ }
+
+ po::notify(vm);
+ }
+ catch (po::error& e)
+ {
+ std::cerr << e.what() << std::endl << std::endl;
+ std::cerr << desc << std::endl;
+ return false;
+ }
+
+ if (!testCaseProvider.ProcessCommandLineOptions())
+ {
+ return false;
+ }
+
+ return true;
+}
+
+bool ValidateDirectory(std::string& dir)
+{
+ if (dir[dir.length() - 1] != '/')
+ {
+ dir += "/";
+ }
+
+ if (!boost::filesystem::exists(dir))
+ {
+ std::cerr << "Given directory " << dir << " does not exist" << std::endl;
+ return false;
+ }
+
+ return true;
+}
+
+bool InferenceTest(const InferenceTestOptions& params,
+ const std::vector<unsigned int>& defaultTestCaseIds,
+ IInferenceTestCaseProvider& testCaseProvider)
+{
+#if !defined (NDEBUG)
+ if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn
+ {
+ BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
+ }
+#endif
+
+ double totalTime = 0;
+ unsigned int nbProcessed = 0;
+ bool success = true;
+
+ // Open the file to write inference times to, if needed
+ ofstream inferenceTimesFile;
+ const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
+ if (recordInferenceTimes)
+ {
+ inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
+ if (!inferenceTimesFile.good())
+ {
+ BOOST_LOG_TRIVIAL(error) << "Failed to open inference times file for writing: "
+ << params.m_InferenceTimesFile;
+ return false;
+ }
+ }
+
+ // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
+ std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
+ if (warmupTestCase == nullptr)
+ {
+ BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
+ return false;
+ }
+
+ try
+ {
+ warmupTestCase->Run();
+ }
+ catch (const TestFrameworkException& testError)
+ {
+ BOOST_LOG_TRIVIAL(error) << testError.what();
+ return false;
+ }
+
+ const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
+ : boost::numeric_cast<unsigned int>(defaultTestCaseIds.size());
+
+ for (; nbProcessed < nbTotalToProcess; nbProcessed++)
+ {
+ const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
+ std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
+
+ if (testCase == nullptr)
+ {
+ BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
+ return false;
+ }
+
+ time_point<high_resolution_clock> predictStart;
+ time_point<high_resolution_clock> predictEnd;
+
+ TestCaseResult result = TestCaseResult::Ok;
+
+ try
+ {
+ predictStart = high_resolution_clock::now();
+
+ testCase->Run();
+
+ predictEnd = high_resolution_clock::now();
+
+ // duration<double> will convert the time difference into seconds as a double by default.
+ double timeTakenS = duration<double>(predictEnd - predictStart).count();
+ totalTime += timeTakenS;
+
+ // Output inference times if needed
+ if (recordInferenceTimes)
+ {
+ inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
+ }
+
+ result = testCase->ProcessResult(params);
+
+ }
+ catch (const TestFrameworkException& testError)
+ {
+ BOOST_LOG_TRIVIAL(error) << testError.what();
+ result = TestCaseResult::Abort;
+ }
+
+ switch (result)
+ {
+ case TestCaseResult::Ok:
+ break;
+ case TestCaseResult::Abort:
+ return false;
+ case TestCaseResult::Failed:
+ // This test failed so we will fail the entire program eventually, but keep going for now.
+ success = false;
+ break;
+ default:
+ BOOST_ASSERT_MSG(false, "Unexpected TestCaseResult");
+ return false;
+ }
+ }
+
+ const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
+
+ BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
+ "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
+ BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
+ "Average time per test case: " << averageTimePerTestCaseMs << " ms";
+
+ if (!success)
+ {
+ BOOST_LOG_TRIVIAL(error) << "One or more test cases failed";
+ return false;
+ }
+
+ return testCaseProvider.OnInferenceTestFinished();
+}
+
+} // namespace test
+
+} // namespace armnn