ArmNN
 20.02
ClassifierTestCaseProvider< TDatabase, InferenceModel > Class Template Reference

#include <InferenceTest.hpp>

Inheritance diagram for ClassifierTestCaseProvider< TDatabase, InferenceModel >:
IInferenceTestCaseProvider

Public Member Functions

template<typename TConstructDatabaseCallable , typename TConstructModelCallable >
 ClassifierTestCaseProvider (TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
 
virtual void AddCommandLineOptions (boost::program_options::options_description &options) override
 
virtual bool ProcessCommandLineOptions (const InferenceTestOptions &commonOptions) override
 
virtual std::unique_ptr< IInferenceTestCase > GetTestCase (unsigned int testCaseId) override
 
virtual bool OnInferenceTestFinished () override
 
- Public Member Functions inherited from IInferenceTestCaseProvider
virtual ~IInferenceTestCaseProvider ()
 

Detailed Description

template<typename TDatabase, typename InferenceModel>
class armnn::test::ClassifierTestCaseProvider< TDatabase, InferenceModel >

Definition at line 177 of file InferenceTest.hpp.

Constructor & Destructor Documentation

◆ ClassifierTestCaseProvider()

ClassifierTestCaseProvider ( TConstructDatabaseCallable  constructDatabase,
TConstructModelCallable  constructModel 
)

Definition at line 175 of file InferenceTest.inl.

177  : m_ConstructModel(constructModel)
178  , m_ConstructDatabase(constructDatabase)
179  , m_NumInferences(0)
180  , m_NumCorrectInferences(0)
181 {
182 }

Member Function Documentation

◆ AddCommandLineOptions()

void AddCommandLineOptions ( boost::program_options::options_description &  options)
override, virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 185 of file InferenceTest.inl.

References InferenceModel< IParser, TDataType >::AddCommandLineOptions().

187 {
188  namespace po = boost::program_options;
189 
190  options.add_options()
191  ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
192  "Reads expected predictions from the given file and confirms they match the actual predictions.")
193  ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
194  "Predictions are saved to the given file for later use via --validation-file-in.")
195  ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
196  "Path to directory containing test data");
197 
198  InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
199 }
static void AddCommandLineOptions(boost::program_options::options_description &desc, CommandLineOptions &options)
armnn::Runtime::CreationOptions::ExternalProfilingOptions options

◆ GetTestCase()

std::unique_ptr< IInferenceTestCase > GetTestCase ( unsigned int  testCaseId)
override, virtual

Implements IInferenceTestCaseProvider.

Definition at line 229 of file InferenceTest.inl.

230 {
231  std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
232  if (testCaseData == nullptr)
233  {
234  return nullptr;
235  }
236 
237  return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
238  m_NumInferences,
239  m_NumCorrectInferences,
240  m_ValidationPredictions,
241  m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
242  *m_Model,
243  testCaseId,
244  testCaseData->m_Label,
245  std::move(testCaseData->m_InputImage));
246 }

◆ OnInferenceTestFinished()

bool OnInferenceTestFinished ( )
override, virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 249 of file InferenceTest.inl.

References ARMNN_LOG, armnn::error, armnn::info, and armnn::numeric_cast().

250 {
251  const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
252  boost::numeric_cast<double>(m_NumInferences);
253  ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
254 
255  // If a validation file was requested as output, the predictions are saved to it.
256  if (!m_ValidationFileOut.empty())
257  {
258  std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
259  if (validationFileOut.good())
260  {
261  for (const unsigned int prediction : m_ValidationPredictionsOut)
262  {
263  validationFileOut << prediction << std::endl;
264  }
265  }
266  else
267  {
268  ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
269  return false;
270  }
271  }
272 
273  return true;
274 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:163
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:33

◆ ProcessCommandLineOptions()

bool ProcessCommandLineOptions ( const InferenceTestOptions &  commonOptions)
override, virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 202 of file InferenceTest.inl.

References armnn::test::ValidateDirectory().

204 {
205  if (!ValidateDirectory(m_DataDir))
206  {
207  return false;
208  }
209 
210  ReadPredictions();
211 
212  m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
213  if (!m_Model)
214  {
215  return false;
216  }
217 
218  m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
219  if (!m_Database)
220  {
221  return false;
222  }
223 
224  return true;
225 }
bool ValidateDirectory(std::string &dir)

The documentation for this class was generated from the following files: