aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2019-05-16 16:33:00 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-05-20 11:53:01 +0000
commitac73760a3731934ff7401d847eb2db7b9a77be02 (patch)
tree3a77741a67e10586d8fc48e14e587ebbd5788315
parent58ef2c6f797f6bdb962016c519ebbc980ec2ed50 (diff)
downloadarmnn-ac73760a3731934ff7401d847eb2db7b9a77be02.tar.gz
IVGCVSW-3060 Classification tests display output value as raw float
Change-Id: I92a1e043d60fa2fe3414dc9339ef36204aca42e2 Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
-rw-r--r--tests/InferenceTest.hpp47
-rw-r--r--tests/InferenceTest.inl99
2 files changed, 68 insertions, 78 deletions
diff --git a/tests/InferenceTest.hpp b/tests/InferenceTest.hpp
index 3ebfdbcc3c..40c9e5e597 100644
--- a/tests/InferenceTest.hpp
+++ b/tests/InferenceTest.hpp
@@ -136,53 +136,6 @@ private:
std::vector<TContainer> m_Outputs;
};
-template <typename TDataType>
-struct ToFloat { }; // nothing defined for the generic case
-
-template <>
-struct ToFloat<float>
-{
- static inline float Convert(float value, const InferenceModelInternal::QuantizationParams &)
- {
- // assuming that float models are not quantized
- return value;
- }
-
- static inline float Convert(int value, const InferenceModelInternal::QuantizationParams &)
- {
- // assuming that float models are not quantized
- return static_cast<float>(value);
- }
-};
-
-template <>
-struct ToFloat<uint8_t>
-{
- static inline float Convert(uint8_t value,
- const InferenceModelInternal::QuantizationParams & quantizationParams)
- {
- return armnn::Dequantize<uint8_t>(value,
- quantizationParams.first,
- quantizationParams.second);
- }
-
- static inline float Convert(int value,
- const InferenceModelInternal::QuantizationParams & quantizationParams)
- {
- return armnn::Dequantize<uint8_t>(static_cast<uint8_t>(value),
- quantizationParams.first,
- quantizationParams.second);
- }
-
- static inline float Convert(float value,
- const InferenceModelInternal::QuantizationParams & quantizationParams)
- {
- return armnn::Dequantize<uint8_t>(static_cast<uint8_t>(value),
- quantizationParams.first,
- quantizationParams.second);
- }
-};
-
template <typename TTestCaseDatabase, typename TModel>
class ClassifierTestCase : public InferenceModelTestCase<TModel>
{
diff --git a/tests/InferenceTest.inl b/tests/InferenceTest.inl
index 0112037bc3..04cae99132 100644
--- a/tests/InferenceTest.inl
+++ b/tests/InferenceTest.inl
@@ -51,48 +51,85 @@ ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
{
}
-template <typename TTestCaseDatabase, typename TModel>
-TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
+struct ClassifierResultProcessor : public boost::static_visitor<>
{
- auto& output = this->GetOutputs()[0];
- const auto testCaseId = this->GetTestCaseId();
+ using ResultMap = std::map<float,int>;
- std::map<float,int> resultMap;
+ ClassifierResultProcessor(float scale, int offset)
+ : m_Scale(scale)
+ , m_Offset(offset)
+ {}
+
+ void operator()(const std::vector<float>& values)
{
- int index = 0;
+ SortPredictions(values, [](float value)
+ {
+ return value;
+ });
+ }
- boost::apply_visitor([&](auto&& value)
- {
- for (const auto & o : value)
- {
- float prob = ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams);
- int classification = index++;
-
- // Take the first class with each probability
- // This avoids strange results when looping over batched results produced
- // with identical test data.
- std::map<float, int>::iterator lb = resultMap.lower_bound(prob);
- if (lb == resultMap.end() ||
- !resultMap.key_comp()(prob, lb->first)) {
- // If the key is not already in the map, insert it.
- resultMap.insert(lb, std::map<float, int>::value_type(prob, classification));
- }
- }
- },
- output);
+ void operator()(const std::vector<uint8_t>& values)
+ {
+ auto& scale = m_Scale;
+ auto& offset = m_Offset;
+ SortPredictions(values, [&scale, &offset](uint8_t value)
+ {
+ return armnn::Dequantize(value, scale, offset);
+ });
}
+ void operator()(const std::vector<int>& values)
{
- BOOST_LOG_TRIVIAL(info) << "= Prediction values for test #" << testCaseId;
- auto it = resultMap.rbegin();
- for (int i=0; i<5 && it != resultMap.rend(); ++i)
+ BOOST_ASSERT_MSG(false, "Non-float predictions output not supported.");
+ }
+
+ ResultMap& GetResultMap() { return m_ResultMap; }
+
+private:
+ template<typename Container, typename Delegate>
+ void SortPredictions(const Container& c, Delegate delegate)
+ {
+ int index = 0;
+ for (const auto& value : c)
{
- BOOST_LOG_TRIVIAL(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
- " with confidence: " << 100.0*(it->first) << "%";
- ++it;
+ int classification = index++;
+ // Take the first class with each probability
+ // This avoids strange results when looping over batched results produced
+ // with identical test data.
+ ResultMap::iterator lb = m_ResultMap.lower_bound(value);
+
+ if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
+ {
+ // If the key is not already in the map, insert it.
+ m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
+ }
}
}
+ ResultMap m_ResultMap;
+
+ float m_Scale=0.0f;
+ int m_Offset=0;
+};
+
+template <typename TTestCaseDatabase, typename TModel>
+TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
+{
+ auto& output = this->GetOutputs()[0];
+ const auto testCaseId = this->GetTestCaseId();
+
+ ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
+ boost::apply_visitor(resultProcessor, output);
+
+ BOOST_LOG_TRIVIAL(info) << "= Prediction values for test #" << testCaseId;
+ auto it = resultProcessor.GetResultMap().rbegin();
+ for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
+ {
+ BOOST_LOG_TRIVIAL(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
+ " with value: " << (it->first);
+ ++it;
+ }
+
unsigned int prediction = 0;
boost::apply_visitor([&](auto&& value)
{