From b40ecf8522052809d2351677a96195d69e4d0c16 Mon Sep 17 00:00:00 2001 From: Richard Burton Date: Fri, 22 Apr 2022 16:14:57 +0100 Subject: MLECO-3174: Minor refactoring to implemented use case APIS Looks large but it is mainly just many small adjustments Removed the inference runner code as it wasn't used Fixes to doc strings Consistent naming e.g. Asr/Kws instead of ASR/KWS Signed-off-by: Richard Burton Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc --- source/use_case/asr/src/AsrClassifier.cc | 196 ++++++++++++++++--------------- 1 file changed, 102 insertions(+), 94 deletions(-) (limited to 'source/use_case/asr/src/AsrClassifier.cc') diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc index 84e66b7..4ba8c7b 100644 --- a/source/use_case/asr/src/AsrClassifier.cc +++ b/source/use_case/asr/src/AsrClassifier.cc @@ -20,117 +20,125 @@ #include "TensorFlowLiteMicro.hpp" #include "Wav2LetterModel.hpp" -template -bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor, - std::vector& vecResults, - const std::vector & labels, double scale, double zeroPoint) -{ - const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx]; - const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]; - - if (nLetters != labels.size()) { - printf("Output size doesn't match the labels' size\n"); - return false; - } +namespace arm { +namespace app { + + template + bool AsrClassifier::GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint) + { + const uint32_t nElems = tensor->dims->data[Wav2LetterModel::ms_outputRowsIdx]; + const uint32_t nLetters = tensor->dims->data[Wav2LetterModel::ms_outputColsIdx]; + + if (nLetters != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } - /* NOTE: tensor's size verification against labels should be - * checked by the calling/public function. */ - if (nLetters < 1) { - return false; - } + /* NOTE: tensor's size verification against labels should be + * checked by the calling/public function. */ + if (nLetters < 1) { + return false; + } - /* Final results' container. */ - vecResults = std::vector(nElems); + /* Final results' container. */ + vecResults = std::vector(nElems); - T* tensorData = tflite::GetTensorData(tensor); + T* tensorData = tflite::GetTensorData(tensor); - /* Get the top 1 results. */ - for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) { - std::pair top_1 = std::make_pair(tensorData[row + 0], 0); + /* Get the top 1 results. */ + for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) { + std::pair top_1 = std::make_pair(tensorData[row + 0], 0); - for (uint32_t j = 1; j < nLetters; ++j) { - if (top_1.first < tensorData[row + j]) { - top_1.first = tensorData[row + j]; - top_1.second = j; + for (uint32_t j = 1; j < nLetters; ++j) { + if (top_1.first < tensorData[row + j]) { + top_1.first = tensorData[row + j]; + top_1.second = j; + } } + + double score = static_cast (top_1.first); + vecResults[i].m_normalisedVal = scale * (score - zeroPoint); + vecResults[i].m_label = labels[top_1.second]; + vecResults[i].m_labelIdx = top_1.second; } - double score = static_cast (top_1.first); - vecResults[i].m_normalisedVal = scale * (score - zeroPoint); - vecResults[i].m_label = labels[top_1.second]; - vecResults[i].m_labelIdx = top_1.second; + return true; } - - return true; -} -template bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor, - std::vector& vecResults, - const std::vector & labels, double scale, double zeroPoint); -template bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor, - std::vector& vecResults, - const std::vector & labels, double scale, double zeroPoint); - -bool arm::app::AsrClassifier::GetClassificationResults( + template bool AsrClassifier::GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, + double scale, double zeroPoint); + template bool AsrClassifier::GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, + double scale, double zeroPoint); + + bool AsrClassifier::GetClassificationResults( TfLiteTensor* outputTensor, std::vector& vecResults, const std::vector & labels, uint32_t topNCount, bool use_softmax) -{ - UNUSED(use_softmax); - vecResults.clear(); + { + UNUSED(use_softmax); + vecResults.clear(); - constexpr int minTensorDims = static_cast( - (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)? - arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx); + constexpr int minTensorDims = static_cast( + (Wav2LetterModel::ms_outputRowsIdx > Wav2LetterModel::ms_outputColsIdx)? + Wav2LetterModel::ms_outputRowsIdx : Wav2LetterModel::ms_outputColsIdx); - constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx; + constexpr uint32_t outColsIdx = Wav2LetterModel::ms_outputColsIdx; - /* Sanity checks. */ - if (outputTensor == nullptr) { - printf_err("Output vector is null pointer.\n"); - return false; - } else if (outputTensor->dims->size < minTensorDims) { - printf_err("Output tensor expected to be %dD\n", minTensorDims); - return false; - } else if (static_cast(outputTensor->dims->data[outColsIdx]) < topNCount) { - printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount); - return false; - } else if (static_cast(outputTensor->dims->data[outColsIdx]) != labels.size()) { - printf("Output size doesn't match the labels' size\n"); - return false; - } + /* Sanity checks. */ + if (outputTensor == nullptr) { + printf_err("Output vector is null pointer.\n"); + return false; + } else if (outputTensor->dims->size < minTensorDims) { + printf_err("Output tensor expected to be %dD\n", minTensorDims); + return false; + } else if (static_cast(outputTensor->dims->data[outColsIdx]) < topNCount) { + printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount); + return false; + } else if (static_cast(outputTensor->dims->data[outColsIdx]) != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } - if (topNCount != 1) { - warn("TopNCount value ignored in this implementation\n"); - } + if (topNCount != 1) { + warn("TopNCount value ignored in this implementation\n"); + } - /* To return the floating point values, we need quantization parameters. */ - QuantParams quantParams = GetTensorQuantParams(outputTensor); - - bool resultState; - - switch (outputTensor->type) { - case kTfLiteUInt8: - resultState = this->GetTopResults( - outputTensor, vecResults, - labels, quantParams.scale, - quantParams.offset); - break; - case kTfLiteInt8: - resultState = this->GetTopResults( - outputTensor, vecResults, - labels, quantParams.scale, - quantParams.offset); - break; - default: - printf_err("Tensor type %s not supported by classifier\n", - TfLiteTypeGetName(outputTensor->type)); + /* To return the floating point values, we need quantization parameters. */ + QuantParams quantParams = GetTensorQuantParams(outputTensor); + + bool resultState; + + switch (outputTensor->type) { + case kTfLiteUInt8: + resultState = this->GetTopResults( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + case kTfLiteInt8: + resultState = this->GetTopResults( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + default: + printf_err("Tensor type %s not supported by classifier\n", + TfLiteTypeGetName(outputTensor->type)); + return false; + } + + if (!resultState) { + printf_err("Failed to get sorted set\n"); return false; - } + } - if (!resultState) { - printf_err("Failed to get sorted set\n"); - return false; - } + return true; + } - return true; -} \ No newline at end of file +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file -- cgit v1.2.1