ArmNN
 21.02
ClassifierTestCaseProvider< TDatabase, InferenceModel > Class Template Reference

#include <InferenceTest.hpp>

Inheritance diagram for ClassifierTestCaseProvider< TDatabase, InferenceModel >:
IInferenceTestCaseProvider

Public Member Functions

template<typename TConstructDatabaseCallable , typename TConstructModelCallable >
 ClassifierTestCaseProvider (TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
 
virtual void AddCommandLineOptions (cxxopts::Options &options, std::vector< std::string > &required) override
 
virtual bool ProcessCommandLineOptions (const InferenceTestOptions &commonOptions) override
 
virtual std::unique_ptr< IInferenceTestCase > GetTestCase (unsigned int testCaseId) override
 
virtual bool OnInferenceTestFinished () override
 
- Public Member Functions inherited from IInferenceTestCaseProvider
virtual ~IInferenceTestCaseProvider ()
 

Detailed Description

template<typename TDatabase, typename InferenceModel>
class armnn::test::ClassifierTestCaseProvider< TDatabase, InferenceModel >

Definition at line 178 of file InferenceTest.hpp.

Constructor & Destructor Documentation

◆ ClassifierTestCaseProvider()

template<typename TConstructDatabaseCallable , typename TConstructModelCallable >
ClassifierTestCaseProvider ( TConstructDatabaseCallable  constructDatabase,
TConstructModelCallable  constructModel 
)

Definition at line 174 of file InferenceTest.inl.

176  : m_ConstructModel(constructModel)
177  , m_ConstructDatabase(constructDatabase)
178  , m_NumInferences(0)
179  , m_NumCorrectInferences(0)
180 {
181 }

Member Function Documentation

◆ AddCommandLineOptions()

void AddCommandLineOptions ( cxxopts::Options &  options,
std::vector< std::string > &  required 
)
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 184 of file InferenceTest.inl.

References InferenceModel< IParser, TDataType >::AddCommandLineOptions().

186 {
187  options
188  .allow_unrecognised_options()
189  .add_options()
190  ("validation-file-in",
191  "Reads expected predictions from the given file and confirms they match the actual predictions.",
192  cxxopts::value<std::string>(m_ValidationFileIn)->default_value(""))
193  ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.",
194  cxxopts::value<std::string>(m_ValidationFileOut)->default_value(""))
195  ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
196 
197  required.emplace_back("data-dir"); //add to required arguments to check
198 
199  InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
200 }
Referenced function InferenceModel< IParser, TDataType >::AddCommandLineOptions: static void AddCommandLineOptions(cxxopts::Options &options, CommandLineOptions &cLineOptions, std::vector< std::string > &required)

◆ GetTestCase()

std::unique_ptr< IInferenceTestCase > GetTestCase ( unsigned int  testCaseId)
override virtual

Implements IInferenceTestCaseProvider.

Definition at line 230 of file InferenceTest.inl.

231 {
232  std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
233  if (testCaseData == nullptr)
234  {
235  return nullptr;
236  }
237 
238  return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
239  m_NumInferences,
240  m_NumCorrectInferences,
241  m_ValidationPredictions,
242  m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
243  *m_Model,
244  testCaseId,
245  testCaseData->m_Label,
246  std::move(testCaseData->m_InputImage));
247 }

◆ OnInferenceTestFinished()

bool OnInferenceTestFinished ( )
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 250 of file InferenceTest.inl.

References ARMNN_LOG, armnn::error, armnn::info, and armnn::numeric_cast().

251 {
252  const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) /
253  armnn::numeric_cast<double>(m_NumInferences);
254  ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
255 
256  // If a validation file was requested as output, the predictions are saved to it.
257  if (!m_ValidationFileOut.empty())
258  {
259  std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
260  if (validationFileOut.good())
261  {
262  for (const unsigned int prediction : m_ValidationPredictionsOut)
263  {
264  validationFileOut << prediction << std::endl;
265  }
266  }
267  else
268  {
269  ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
270  return false;
271  }
272  }
273 
274  return true;
275 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:202
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:35

◆ ProcessCommandLineOptions()

bool ProcessCommandLineOptions ( const InferenceTestOptions &  commonOptions)
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 203 of file InferenceTest.inl.

References armnn::test::ValidateDirectory().

205 {
206  if (!ValidateDirectory(m_DataDir))
207  {
208  return false;
209  }
210 
211  ReadPredictions();
212 
213  m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
214  if (!m_Model)
215  {
216  return false;
217  }
218 
219  m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
220  if (!m_Database)
221  {
222  return false;
223  }
224 
225  return true;
226 }
Referenced function armnn::test::ValidateDirectory: bool ValidateDirectory(std::string &dir)

The documentation for this class was generated from the following files: