aboutsummaryrefslogtreecommitdiff
path: root/samples/SpeechRecognition/include/Decoder.hpp
diff options
context:
space:
mode:
authorÉanna Ó Catháin <eanna.ocathain@arm.com>2021-04-07 14:35:25 +0100
committerJim Flynn <jim.flynn@arm.com>2021-05-07 09:11:52 +0000
commitc6ab02a626e15b4a12fc09ecd844eb8b95380c3c (patch)
tree9912ed9cdb89cdb24483b22d6621ae30049ae321 /samples/SpeechRecognition/include/Decoder.hpp
parente813d67f86df41a238ff79b5c554ef5027f56576 (diff)
downloadarmnn-c6ab02a626e15b4a12fc09ecd844eb8b95380c3c.tar.gz
MLECO-1252 ASR sample application using the public ArmNN C++ API.
Change-Id: I98cd505b8772a8c8fa88308121bc94135bb45068 Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Diffstat (limited to 'samples/SpeechRecognition/include/Decoder.hpp')
-rw-r--r--samples/SpeechRecognition/include/Decoder.hpp63
1 files changed, 63 insertions, 0 deletions
diff --git a/samples/SpeechRecognition/include/Decoder.hpp b/samples/SpeechRecognition/include/Decoder.hpp
new file mode 100644
index 0000000000..69d97ccf64
--- /dev/null
+++ b/samples/SpeechRecognition/include/Decoder.hpp
@@ -0,0 +1,63 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <string>
+#include <map>
+#include <vector>
+#include <algorithm>
+#include <cmath>
+
+# pragma once
+
+namespace asr
+{
+/**
+* @brief Class used to Decode the output of the ASR inference
+*
+*/
+ class Decoder
+ {
+ public:
+ std::map<int, std::string> m_labels;
+ /**
+ * @brief Default constructor
+ * @param[in] labels - map of labels to be used for decoding to text.
+ */
+ Decoder(std::map<int, std::string>& labels);
+
+ /**
+ * @brief Function to decode the output into a text string
+ * @param[in] output - the output vector to decode.
+ */
+ template<typename T>
+ std::string DecodeOutput(std::vector<T>& contextToProcess)
+ {
+ int rowLength = 29;
+
+ std::vector<char> unfilteredText;
+
+ for(int row = 0; row < contextToProcess.size()/rowLength; ++row)
+ {
+ std::vector<int16_t> rowVector;
+ for(int j = 0; j < rowLength; ++j)
+ {
+ rowVector.emplace_back(static_cast<int16_t>(contextToProcess[row * rowLength + j]));
+ }
+
+ int max_index = std::distance(rowVector.begin(),std::max_element(rowVector.begin(), rowVector.end()));
+ unfilteredText.emplace_back(this->m_labels.at(max_index)[0]);
+ }
+
+ std::string filteredText = FilterCharacters(unfilteredText);
+ return filteredText;
+ }
+
+ /**
+ * @brief Function to filter out unwanted characters
+ * @param[in] unfiltered - the unfiltered output to be processed.
+ */
+ std::string FilterCharacters(std::vector<char>& unfiltered);
+ };
+} // namespace asr