summaryrefslogtreecommitdiff
path: root/source/application/main/Classifier.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/application/main/Classifier.cc')
-rw-r--r--source/application/main/Classifier.cc117
1 files changed, 46 insertions, 71 deletions
diff --git a/source/application/main/Classifier.cc b/source/application/main/Classifier.cc
index bc2c378..9a47f3d 100644
--- a/source/application/main/Classifier.cc
+++ b/source/application/main/Classifier.cc
@@ -28,69 +28,52 @@ namespace arm {
namespace app {
template<typename T>
- bool Classifier::_GetTopNResults(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- uint32_t topNCount,
- const std::vector <std::string>& labels)
- {
- std::set<std::pair<T, uint32_t>> sortedSet;
-
- /* NOTE: inputVec's size verification against labels should be
- * checked by the calling/public function. */
- T* tensorData = tflite::GetTensorData<T>(tensor);
-
- /* Set initial elements. */
- for (uint32_t i = 0; i < topNCount; ++i) {
- sortedSet.insert({tensorData[i], i});
- }
-
- /* Initialise iterator. */
- auto setFwdIter = sortedSet.begin();
-
- /* Scan through the rest of elements with compare operations. */
- for (uint32_t i = topNCount; i < labels.size(); ++i) {
- if (setFwdIter->first < tensorData[i]) {
- sortedSet.erase(*setFwdIter);
- sortedSet.insert({tensorData[i], i});
- setFwdIter = sortedSet.begin();
- }
- }
-
- /* Final results' container. */
- vecResults = std::vector<ClassificationResult>(topNCount);
+ void SetVectorResults(std::set<std::pair<T, uint32_t>>& topNSet,
+ std::vector<ClassificationResult>& vecResults,
+ TfLiteTensor* tensor,
+ const std::vector <std::string>& labels) {
/* For getting the floating point values, we need quantization parameters. */
QuantParams quantParams = GetTensorQuantParams(tensor);
/* Reset the iterator to the largest element - use reverse iterator. */
- auto setRevIter = sortedSet.rbegin();
-
- /* Populate results
- * Note: we could combine this loop with the loop above, but that
- * would, involve more multiplications and other operations.
- **/
- for (size_t i = 0; i < vecResults.size(); ++i, ++setRevIter) {
- double score = static_cast<int> (setRevIter->first);
- vecResults[i].m_normalisedVal = quantParams.scale *
- (score - quantParams.offset);
- vecResults[i].m_label = labels[setRevIter->second];
- vecResults[i].m_labelIdx = setRevIter->second;
+ auto topNIter = topNSet.rbegin();
+ for (size_t i = 0; i < vecResults.size() && topNIter != topNSet.rend(); ++i, ++topNIter) {
+ T score = topNIter->first;
+ vecResults[i].m_normalisedVal = quantParams.scale * (score - quantParams.offset);
+ vecResults[i].m_label = labels[topNIter->second];
+ vecResults[i].m_labelIdx = topNIter->second;
}
- return true;
}
template<>
- bool Classifier::_GetTopNResults<float>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- uint32_t topNCount,
- const std::vector <std::string>& labels)
+ void SetVectorResults<float>(std::set<std::pair<float, uint32_t>>& topNSet,
+ std::vector<ClassificationResult>& vecResults,
+ TfLiteTensor* tensor,
+ const std::vector <std::string>& labels) {
+ UNUSED(tensor);
+ /* Reset the iterator to the largest element - use reverse iterator. */
+ auto topNIter = topNSet.rbegin();
+ for (size_t i = 0; i < vecResults.size() && topNIter != topNSet.rend(); ++i, ++topNIter) {
+ vecResults[i].m_normalisedVal = topNIter->first;
+ vecResults[i].m_label = labels[topNIter->second];
+ vecResults[i].m_labelIdx = topNIter->second;
+ }
+
+ }
+
+ template<typename T>
+ bool Classifier::GetTopNResults(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ uint32_t topNCount,
+ const std::vector <std::string>& labels)
{
- std::set<std::pair<float, uint32_t>> sortedSet;
+ std::set<std::pair<T, uint32_t>> sortedSet;
/* NOTE: inputVec's size verification against labels should be
* checked by the calling/public function. */
- float* tensorData = tflite::GetTensorData<float>(tensor);
+ T* tensorData = tflite::GetTensorData<T>(tensor);
/* Set initial elements. */
for (uint32_t i = 0; i < topNCount; ++i) {
@@ -112,29 +95,18 @@ namespace app {
/* Final results' container. */
vecResults = std::vector<ClassificationResult>(topNCount);
- /* Reset the iterator to the largest element - use reverse iterator. */
- auto setRevIter = sortedSet.rbegin();
-
- /* Populate results
- * Note: we could combine this loop with the loop above, but that
- * would, involve more multiplications and other operations.
- **/
- for (size_t i = 0; i < vecResults.size(); ++i, ++setRevIter) {
- vecResults[i].m_normalisedVal = setRevIter->first;
- vecResults[i].m_label = labels[setRevIter->second];
- vecResults[i].m_labelIdx = setRevIter->second;
- }
+ SetVectorResults<T>(sortedSet, vecResults, tensor, labels);
return true;
}
- template bool Classifier::_GetTopNResults<uint8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- uint32_t topNCount, const std::vector <std::string>& labels);
+ template bool Classifier::GetTopNResults<uint8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ uint32_t topNCount, const std::vector <std::string>& labels);
- template bool Classifier::_GetTopNResults<int8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- uint32_t topNCount, const std::vector <std::string>& labels);
+ template bool Classifier::GetTopNResults<int8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ uint32_t topNCount, const std::vector <std::string>& labels);
bool Classifier::GetClassificationResults(
TfLiteTensor* outputTensor,
@@ -158,6 +130,9 @@ namespace app {
} else if (totalOutputSize != labels.size()) {
printf_err("Output size doesn't match the labels' size\n");
return false;
+ } else if (topNCount == 0) {
+ printf_err("Top N results cannot be zero\n");
+ return false;
}
bool resultState;
@@ -166,13 +141,13 @@ namespace app {
/* Get the top N results. */
switch (outputTensor->type) {
case kTfLiteUInt8:
- resultState = _GetTopNResults<uint8_t>(outputTensor, vecResults, topNCount, labels);
+ resultState = GetTopNResults<uint8_t>(outputTensor, vecResults, topNCount, labels);
break;
case kTfLiteInt8:
- resultState = _GetTopNResults<int8_t>(outputTensor, vecResults, topNCount, labels);
+ resultState = GetTopNResults<int8_t>(outputTensor, vecResults, topNCount, labels);
break;
case kTfLiteFloat32:
- resultState = _GetTopNResults<float>(outputTensor, vecResults, topNCount, labels);
+ resultState = GetTopNResults<float>(outputTensor, vecResults, topNCount, labels);
break;
default:
printf_err("Tensor type %s not supported by classifier\n", TfLiteTypeGetName(outputTensor->type));
@@ -180,7 +155,7 @@ namespace app {
}
if (!resultState) {
- printf_err("Failed to get sorted set\n");
+ printf_err("Failed to get top N results set\n");
return false;
}