ArmNN
 22.05
ClassifierTestCaseProvider< TDatabase, InferenceModel > Class Template Reference

#include <InferenceTest.hpp>

Inheritance diagram for ClassifierTestCaseProvider< TDatabase, InferenceModel >:
IInferenceTestCaseProvider

Public Member Functions

template<typename TConstructDatabaseCallable , typename TConstructModelCallable >
 ClassifierTestCaseProvider (TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
 
virtual void AddCommandLineOptions (cxxopts::Options &options, std::vector< std::string > &required) override
 
virtual bool ProcessCommandLineOptions (const InferenceTestOptions &commonOptions) override
 
virtual std::unique_ptr< IInferenceTestCase > GetTestCase (unsigned int testCaseId) override
 
virtual bool OnInferenceTestFinished () override
 
Public Member Functions inherited from IInferenceTestCaseProvider
virtual ~IInferenceTestCaseProvider ()
 

Detailed Description

template<typename TDatabase, typename InferenceModel>
class armnn::test::ClassifierTestCaseProvider< TDatabase, InferenceModel >

Definition at line 179 of file InferenceTest.hpp.

Constructor & Destructor Documentation

◆ ClassifierTestCaseProvider()

ClassifierTestCaseProvider ( TConstructDatabaseCallable  constructDatabase,
TConstructModelCallable  constructModel 
)

Definition at line 183 of file InferenceTest.inl.

185  : m_ConstructModel(constructModel)
186  , m_ConstructDatabase(constructDatabase)
187  , m_NumInferences(0)
188  , m_NumCorrectInferences(0)
189 {
190 }

Member Function Documentation

◆ AddCommandLineOptions()

void AddCommandLineOptions ( cxxopts::Options &  options,
std::vector< std::string > &  required 
)
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 193 of file InferenceTest.inl.

References InferenceModel< IParser, TDataType >::AddCommandLineOptions().

195 {
196  options
197  .allow_unrecognised_options()
198  .add_options()
199  ("validation-file-in",
200  "Reads expected predictions from the given file and confirms they match the actual predictions.",
201  cxxopts::value<std::string>(m_ValidationFileIn)->default_value(""))
202  ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.",
203  cxxopts::value<std::string>(m_ValidationFileOut)->default_value(""))
204  ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
205 
206  required.emplace_back("data-dir"); //add to required arguments to check
207 
208  InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
209 }
static void AddCommandLineOptions(cxxopts::Options &options, CommandLineOptions &cLineOptions, std::vector< std::string > &required)

◆ GetTestCase()

std::unique_ptr< IInferenceTestCase > GetTestCase ( unsigned int  testCaseId)
override virtual

Implements IInferenceTestCaseProvider.

Definition at line 239 of file InferenceTest.inl.

240 {
241  std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
242  if (testCaseData == nullptr)
243  {
244  return nullptr;
245  }
246 
247  return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
248  m_NumInferences,
249  m_NumCorrectInferences,
250  m_ValidationPredictions,
251  m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
252  *m_Model,
253  testCaseId,
254  testCaseData->m_Label,
255  std::move(testCaseData->m_InputImage));
256 }

◆ OnInferenceTestFinished()

bool OnInferenceTestFinished ( )
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 259 of file InferenceTest.inl.

References ARMNN_LOG, armnn::error, armnn::info, and armnn::numeric_cast().

260 {
261  const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) /
262  armnn::numeric_cast<double>(m_NumInferences);
263  ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
264 
265  // If a validation file was requested as output, the predictions are saved to it.
266  if (!m_ValidationFileOut.empty())
267  {
268  std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
269  if (validationFileOut.good())
270  {
271  for (const unsigned int prediction : m_ValidationPredictionsOut)
272  {
273  validationFileOut << prediction << std::endl;
274  }
275  }
276  else
277  {
278  ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
279  return false;
280  }
281  }
282 
283  return true;
284 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:205
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:35

◆ ProcessCommandLineOptions()

bool ProcessCommandLineOptions ( const InferenceTestOptions &  commonOptions)
override virtual

Reimplemented from IInferenceTestCaseProvider.

Definition at line 212 of file InferenceTest.inl.

References armnn::test::ValidateDirectory().

214 {
215  if (!ValidateDirectory(m_DataDir))
216  {
217  return false;
218  }
219 
220  ReadPredictions();
221 
222  m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
223  if (!m_Model)
224  {
225  return false;
226  }
227 
228  m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
229  if (!m_Database)
230  {
231  return false;
232  }
233 
234  return true;
235 }
bool ValidateDirectory(std::string &dir)

The documentation for this class was generated from the following files: