ArmNN
 20.08
ClassifierTestCase< TTestCaseDatabase, TModel > Class Template Reference

#include <InferenceTest.hpp>

Inheritance diagram for ClassifierTestCase< TTestCaseDatabase, TModel >:
ClassifierTestCase< TTestCaseDatabase, TModel > derives from InferenceModelTestCase< TModel >, which in turn derives from IInferenceTestCase.

Public Member Functions

 ClassifierTestCase (int &numInferencesRef, int &numCorrectInferencesRef, const std::vector< unsigned int > &validationPredictions, std::vector< unsigned int > *validationPredictionsOut, TModel &model, unsigned int testCaseId, unsigned int label, std::vector< typename TModel::DataType > modelInput)
 
virtual TestCaseResult ProcessResult (const InferenceTestOptions &params) override
 
- Public Member Functions inherited from InferenceModelTestCase< TModel >
 InferenceModelTestCase (TModel &model, unsigned int testCaseId, const std::vector< TContainer > &inputs, const std::vector< unsigned int > &outputSizes)
 
virtual void Run () override
 
- Public Member Functions inherited from IInferenceTestCase
virtual ~IInferenceTestCase ()
 

Additional Inherited Members

- Public Types inherited from InferenceModelTestCase< TModel >
using TContainer = boost::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > >
 
- Protected Member Functions inherited from InferenceModelTestCase< TModel >
unsigned int GetTestCaseId () const
 
const std::vector< TContainer > & GetOutputs () const
 

Detailed Description

template<typename TTestCaseDatabase, typename TModel>
class armnn::test::ClassifierTestCase< TTestCaseDatabase, TModel >

Definition at line 149 of file InferenceTest.hpp.

Constructor & Destructor Documentation

◆ ClassifierTestCase()

ClassifierTestCase ( int &  numInferencesRef,
int &  numCorrectInferencesRef,
const std::vector< unsigned int > &  validationPredictions,
std::vector< unsigned int > *  validationPredictionsOut,
TModel &  model,
unsigned int  testCaseId,
unsigned int  label,
std::vector< typename TModel::DataType >  modelInput 
)

Definition at line 30 of file InferenceTest.inl.

References ARMNN_ASSERT_MSG, armnn::Dequantize, and armnn::IgnoreUnused().

40  model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() })
41  , m_Label(label)
42  , m_QuantizationParams(model.GetQuantizationParams())
43  , m_NumInferencesRef(numInferencesRef)
44  , m_NumCorrectInferencesRef(numCorrectInferencesRef)
45  , m_ValidationPredictions(validationPredictions)
46  , m_ValidationPredictionsOut(validationPredictionsOut)
47 {
48 }

Member Function Documentation

◆ ProcessResult()

TestCaseResult ProcessResult ( const InferenceTestOptions &  params)
[override], [virtual]

Implements IInferenceTestCase.

Definition at line 113 of file InferenceTest.inl.

References ARMNN_LOG, armnn::error, armnn::test::Failed, InferenceModelTestCase< TModel >::GetOutputs(), InferenceModelTestCase< TModel >::GetTestCaseId(), armnn::info, InferenceTestOptions::m_IterationCount, armnn::numeric_cast(), and armnn::test::Ok.

114 {
115  auto& output = this->GetOutputs()[0];
116  const auto testCaseId = this->GetTestCaseId();
117 
118  ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
119  boost::apply_visitor(resultProcessor, output);
120 
121  ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
122  auto it = resultProcessor.GetResultMap().rbegin();
123  for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
124  {
125  ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
126  " with value: " << (it->first);
127  ++it;
128  }
129 
130  unsigned int prediction = 0;
131  boost::apply_visitor([&](auto&& value)
132  {
133  prediction = boost::numeric_cast<unsigned int>(
134  std::distance(value.begin(), std::max_element(value.begin(), value.end())));
135  },
136  output);
137 
138  // If we're just running the defaultTestCaseIds, each one must be classified correctly.
139  if (params.m_IterationCount == 0 && prediction != m_Label)
140  {
141  ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
142  " is incorrect (should be " << m_Label << ")";
143  return TestCaseResult::Failed;
144  }
145 
146  // If a validation file was provided as input, it checks that the prediction matches.
147  if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
148  {
149  ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
150  " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
151  return TestCaseResult::Failed;
152  }
153 
154  // If a validation file was requested as output, it stores the predictions.
155  if (m_ValidationPredictionsOut)
156  {
157  m_ValidationPredictionsOut->push_back(prediction);
158  }
159 
160  // Updates accuracy stats.
161  m_NumInferencesRef++;
162  if (prediction == m_Label)
163  {
164  m_NumCorrectInferencesRef++;
165  }
166 
167  return TestCaseResult::Ok;
168 }
Tooltips for symbols used in the listings above:
- GetOutputs: const std::vector< TContainer > & GetOutputs() const
- ARMNN_LOG: #define ARMNN_LOG(severity) — Definition: Logging.hpp:163
- armnn::numeric_cast: std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source) — Definition: NumericCast.hpp:33
- TestCaseResult::Ok: The test completed without any errors.

The documentation for this class was generated from the following files: