ArmNN
 20.05
ClassifierTestCaseProvider< TDatabase, InferenceModel > Class Template Reference

#include <InferenceTest.hpp>

Inheritance diagram for ClassifierTestCaseProvider< TDatabase, InferenceModel >:
IInferenceTestCaseProvider

Public Member Functions

template<typename TConstructDatabaseCallable , typename TConstructModelCallable >
 ClassifierTestCaseProvider (TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
 
virtual void AddCommandLineOptions (boost::program_options::options_description &options) override
 
virtual bool ProcessCommandLineOptions (const InferenceTestOptions &commonOptions) override
 
virtual std::unique_ptr< IInferenceTestCase > GetTestCase (unsigned int testCaseId) override
 
virtual bool OnInferenceTestFinished () override
 
- Public Member Functions inherited from IInferenceTestCaseProvider
virtual ~IInferenceTestCaseProvider ()
 

Detailed Description

template<typename TDatabase, typename InferenceModel>
class armnn::test::ClassifierTestCaseProvider< TDatabase, InferenceModel >

Definition at line 177 of file InferenceTest.hpp.

Constructor & Destructor Documentation

◆ ClassifierTestCaseProvider()

ClassifierTestCaseProvider ( TConstructDatabaseCallable  constructDatabase,
TConstructModelCallable  constructModel 
)

Definition at line 174 of file InferenceTest.inl.

176  : m_ConstructModel(constructModel)
177  , m_ConstructDatabase(constructDatabase)
178  , m_NumInferences(0)
179  , m_NumCorrectInferences(0)
180 {
181 }

Member Function Documentation

◆ AddCommandLineOptions()

void AddCommandLineOptions ( boost::program_options::options_description &  options)
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 184 of file InferenceTest.inl.

References InferenceModel< IParser, TDataType >::AddCommandLineOptions().

186 {
187  namespace po = boost::program_options;
188 
189  options.add_options()
190  ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
191  "Reads expected predictions from the given file and confirms they match the actual predictions.")
192  ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
193  "Predictions are saved to the given file for later use via --validation-file-in.")
194  ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
195  "Path to directory containing test data");
196 
197  InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
198 }
static void AddCommandLineOptions(boost::program_options::options_description &desc, CommandLineOptions &options)
armnn::Runtime::CreationOptions::ExternalProfilingOptions options

◆ GetTestCase()

std::unique_ptr< IInferenceTestCase > GetTestCase ( unsigned int  testCaseId)
override virtual

Implements IInferenceTestCaseProvider.

Definition at line 228 of file InferenceTest.inl.

229 {
230  std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
231  if (testCaseData == nullptr)
232  {
233  return nullptr;
234  }
235 
236  return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
237  m_NumInferences,
238  m_NumCorrectInferences,
239  m_ValidationPredictions,
240  m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
241  *m_Model,
242  testCaseId,
243  testCaseData->m_Label,
244  std::move(testCaseData->m_InputImage));
245 }

◆ OnInferenceTestFinished()

bool OnInferenceTestFinished ( )
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 248 of file InferenceTest.inl.

References ARMNN_LOG, armnn::error, armnn::info, and armnn::numeric_cast().

249 {
250  const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
251  boost::numeric_cast<double>(m_NumInferences);
252  ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
253 
254  // If a validation file was requested as output, the predictions are saved to it.
255  if (!m_ValidationFileOut.empty())
256  {
257  std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
258  if (validationFileOut.good())
259  {
260  for (const unsigned int prediction : m_ValidationPredictionsOut)
261  {
262  validationFileOut << prediction << std::endl;
263  }
264  }
265  else
266  {
267  ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
268  return false;
269  }
270  }
271 
272  return true;
273 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:163
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:33

◆ ProcessCommandLineOptions()

bool ProcessCommandLineOptions ( const InferenceTestOptions &  commonOptions)
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 201 of file InferenceTest.inl.

References armnn::test::ValidateDirectory().

203 {
204  if (!ValidateDirectory(m_DataDir))
205  {
206  return false;
207  }
208 
209  ReadPredictions();
210 
211  m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
212  if (!m_Model)
213  {
214  return false;
215  }
216 
217  m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
218  if (!m_Database)
219  {
220  return false;
221  }
222 
223  return true;
224 }
bool ValidateDirectory(std::string &dir)

The documentation for this class was generated from the following files: