ArmNN
 20.08
ClassifierTestCaseProvider< TDatabase, InferenceModel > Class Template Reference

#include <InferenceTest.hpp>

Inheritance diagram for ClassifierTestCaseProvider< TDatabase, InferenceModel >:
IInferenceTestCaseProvider

Public Member Functions

template<typename TConstructDatabaseCallable , typename TConstructModelCallable >
 ClassifierTestCaseProvider (TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
 
virtual void AddCommandLineOptions (boost::program_options::options_description &options) override
 
virtual bool ProcessCommandLineOptions (const InferenceTestOptions &commonOptions) override
 
virtual std::unique_ptr< IInferenceTestCase > GetTestCase (unsigned int testCaseId) override
 
virtual bool OnInferenceTestFinished () override
 
- Public Member Functions inherited from IInferenceTestCaseProvider
virtual ~IInferenceTestCaseProvider ()
 

Detailed Description

template<typename TDatabase, typename InferenceModel>
class armnn::test::ClassifierTestCaseProvider< TDatabase, InferenceModel >

Definition at line 177 of file InferenceTest.hpp.

Constructor & Destructor Documentation

◆ ClassifierTestCaseProvider()

ClassifierTestCaseProvider ( TConstructDatabaseCallable  constructDatabase,
TConstructModelCallable  constructModel 
)

Definition at line 172 of file InferenceTest.inl.

174  : m_ConstructModel(constructModel)
175  , m_ConstructDatabase(constructDatabase)
176  , m_NumInferences(0)
177  , m_NumCorrectInferences(0)
178 {
179 }

Member Function Documentation

◆ AddCommandLineOptions()

void AddCommandLineOptions ( boost::program_options::options_description &  options)
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 182 of file InferenceTest.inl.

References InferenceModel< IParser, TDataType >::AddCommandLineOptions().

184 {
185  namespace po = boost::program_options;
186 
187  options.add_options()
188  ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
189  "Reads expected predictions from the given file and confirms they match the actual predictions.")
190  ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
191  "Predictions are saved to the given file for later use via --validation-file-in.")
192  ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
193  "Path to directory containing test data");
194 
195  InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
196 }
static void AddCommandLineOptions(boost::program_options::options_description &desc, CommandLineOptions &options)

◆ GetTestCase()

std::unique_ptr< IInferenceTestCase > GetTestCase ( unsigned int  testCaseId)
override virtual

Implements IInferenceTestCaseProvider.

Definition at line 226 of file InferenceTest.inl.

227 {
228  std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
229  if (testCaseData == nullptr)
230  {
231  return nullptr;
232  }
233 
234  return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
235  m_NumInferences,
236  m_NumCorrectInferences,
237  m_ValidationPredictions,
238  m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
239  *m_Model,
240  testCaseId,
241  testCaseData->m_Label,
242  std::move(testCaseData->m_InputImage));
243 }

◆ OnInferenceTestFinished()

bool OnInferenceTestFinished ( )
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 246 of file InferenceTest.inl.

References ARMNN_LOG, armnn::error, armnn::info, and armnn::numeric_cast().

247 {
248  const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
249  boost::numeric_cast<double>(m_NumInferences);
250  ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
251 
252  // If a validation file was requested as output, the predictions are saved to it.
253  if (!m_ValidationFileOut.empty())
254  {
255  std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
256  if (validationFileOut.good())
257  {
258  for (const unsigned int prediction : m_ValidationPredictionsOut)
259  {
260  validationFileOut << prediction << std::endl;
261  }
262  }
263  else
264  {
265  ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
266  return false;
267  }
268  }
269 
270  return true;
271 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:163
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:33

◆ ProcessCommandLineOptions()

bool ProcessCommandLineOptions ( const InferenceTestOptions &  commonOptions)
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 199 of file InferenceTest.inl.

References armnn::test::ValidateDirectory().

201 {
202  if (!ValidateDirectory(m_DataDir))
203  {
204  return false;
205  }
206 
207  ReadPredictions();
208 
209  m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
210  if (!m_Model)
211  {
212  return false;
213  }
214 
215  m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
216  if (!m_Database)
217  {
218  return false;
219  }
220 
221  return true;
222 }
bool ValidateDirectory(std::string &dir)

The documentation for this class was generated from the following files: