//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <string>
#include <vector>

namespace asr
{
/**
 * @brief Class used to decode the output of the ASR inference into text.
 *
 * The inference output is consumed as consecutive rows of scores, one score per
 * label; for every row the highest-scoring label index is looked up in m_labels
 * and its first character is emitted.
 */
class Decoder
{
public:
    /// Label index (position within an output row) -> label text.
    /// Only the first character of each label is used by DecodeOutput.
    std::map<int, std::string> m_labels;

    /**
     * @brief Construct a decoder from a label map.
     * @param[in] labels - map of labels to be used for decoding to text
     *                     (presumably copied into m_labels; the definition is not in this header).
     */
    Decoder(std::map<int, std::string>& labels);

    /**
     * @brief Decode the inference output into a text string.
     *
     * The input is split into rows of 29 values; the arg-max of each row (values
     * compared after conversion to int16_t, first maximum wins on ties) selects a
     * label whose first character is appended to the raw text, which is then passed
     * through FilterCharacters.
     *
     * @tparam T element type of the output tensor.
     * @param[in] contextToProcess - output values, 29 consecutive values per row.
     *                               Trailing values that do not fill a whole row are ignored.
     * @return the filtered, decoded text.
     * @throws std::out_of_range if a row's winning index has no entry in m_labels.
     * NOTE(review): for floating-point T, values outside the int16_t range make the
     * conversion undefined — this matches the previous behaviour; confirm T is quantised.
     */
    template<typename T>
    std::string DecodeOutput(const std::vector<T>& contextToProcess)
    {
        // Number of scores per output row (one per label).
        constexpr std::size_t rowLength = 29;
        const std::size_t numRows = contextToProcess.size() / rowLength;

        // Compare scores as int16_t so the selected index is exactly what a copy
        // into an int16_t buffer would give, without allocating that buffer per row.
        const auto lessAsInt16 = [](const T& lhs, const T& rhs)
        {
            return static_cast<std::int16_t>(lhs) < static_cast<std::int16_t>(rhs);
        };

        std::vector<char> unfilteredText;
        unfilteredText.reserve(numRows);

        for (std::size_t row = 0; row < numRows; ++row)
        {
            const auto rowBegin = contextToProcess.begin() + static_cast<std::ptrdiff_t>(row * rowLength);
            const auto rowEnd = rowBegin + static_cast<std::ptrdiff_t>(rowLength);

            const auto maxIndex = static_cast<int>(
                std::distance(rowBegin, std::max_element(rowBegin, rowEnd, lessAsInt16)));

            // operator[](0) on an empty label yields '\0' rather than being out of range.
            unfilteredText.emplace_back(m_labels.at(maxIndex)[0]);
        }

        return FilterCharacters(unfilteredText);
    }

    /**
     * @brief Function to filter out unwanted characters.
     * @param[in] unfiltered - the unfiltered output to be processed.
     * @return the filtered text.
     */
    std::string FilterCharacters(std::vector<char>& unfiltered);
};
} // namespace asr