From aa4bcb14d0cbee910331545dd2fc086b58c37170 Mon Sep 17 00:00:00 2001 From: Kshitij Sisodia Date: Fri, 6 May 2022 09:13:03 +0100 Subject: MLECO-3183: Refactoring application sources Platform agnostic application sources are moved into application api module with their own independent CMake projects. Changes for MLECO-3080 also included - they create CMake projects individial API's (again, platform agnostic) that dependent on the common logic. The API for KWS_API "joint" API has been removed and now the use case relies on individual KWS, and ASR API libraries. Change-Id: I1f7748dc767abb3904634a04e0991b74ac7b756d Signed-off-by: Kshitij Sisodia --- .../api/use_case/asr/src/Wav2LetterPostprocess.cc | 214 +++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc (limited to 'source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc') diff --git a/source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc b/source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc new file mode 100644 index 0000000..00e689b --- /dev/null +++ b/source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterPostprocess.hpp" + +#include "Wav2LetterModel.hpp" +#include "log_macros.h" + +#include + +namespace arm { +namespace app { + + AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector& labels, std::vector& results, + const uint32_t outputContextLen, + const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx + ): + m_classifier(classifier), + m_outputTensor(outputTensor), + m_labels{labels}, + m_results(results), + m_outputContextLen(outputContextLen), + m_countIterations(0), + m_blankTokenIdx(blankTokenIdx), + m_reductionAxisIdx(reductionAxisIdx) + { + this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); + this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen); + } + + bool AsrPostProcess::DoPostProcess() + { + /* Basic checks. */ + if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) { + return false; + } + + /* Irrespective of tensor type, we use unsigned "byte" */ + auto* ptrData = tflite::GetTensorData(this->m_outputTensor); + const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor); + + /* Other sanity checks. */ + if (0 == elemSz) { + printf_err("Tensor type not supported for post processing\n"); + return false; + } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) { + printf_err("Insufficient number of tensor bytes\n"); + return false; + } + + /* Which axis do we need to process? */ + switch (this->m_reductionAxisIdx) { + case Wav2LetterModel::ms_outputRowsIdx: + this->EraseSectionsRowWise( + ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx], + this->m_lastIteration); + break; + default: + printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx); + return false; + } + this->m_classifier.GetClassificationResults(this->m_outputTensor, + this->m_results, this->m_labels, 1); + + return true; + } + + bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const + { + if (nullptr == tensor) { + return false; + } + + if (static_cast(axisIdx) >= tensor->dims->size) { + printf_err("Invalid axis index: %" PRIu32 "; Max: %d\n", + axisIdx, tensor->dims->size); + return false; + } + + if (static_cast(this->m_totalLen) != + tensor->dims->data[axisIdx]) { + printf_err("Unexpected tensor dimension for axis %" PRIu32", got %d.\n", + axisIdx, tensor->dims->data[axisIdx]); + return false; + } + + return true; + } + + uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor) + { + switch(tensor->type) { + case kTfLiteUInt8: + case kTfLiteInt8: + return 1; + case kTfLiteInt16: + return 2; + case kTfLiteInt32: + case kTfLiteFloat32: + return 4; + default: + printf_err("Unsupported tensor type %s\n", + TfLiteTypeGetName(tensor->type)); + } + + return 0; + } + + bool AsrPostProcess::EraseSectionsRowWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) + { + /* In this case, the "zero-ing" is quite simple as the region + * to be zeroed sits in contiguous memory (row-major). */ + const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen; + + /* Erase left context? */ + if (this->m_countIterations > 0) { + /* Set output of each classification window to the blank token. */ + std::memset(ptrData, 0, eraseLen); + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { + ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; + } + } + + /* Erase right context? */ + if (false == lastIteration) { + uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen)); + /* Set output of each classification window to the blank token. */ + std::memset(rightCtxPtr, 0, eraseLen); + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { + rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; + } + } + + if (lastIteration) { + this->m_countIterations = 0; + } else { + ++this->m_countIterations; + } + + return true; + } + + uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model) + { + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0); + if (inputRows == 0) { + printf_err("Error getting number of input rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_inputRowsIdx); + } + return inputRows; + } + + uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) + { + const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + } + + /* Watching for underflow. */ + int innerLen = (outputRows - (2 * outputCtxLen)); + + return std::max(innerLen, 0); + } + + uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) + { + const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx; + + /* Check to make sure that the input tensor supports the above + * context and inner lengths. */ + if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { + printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", + inputCtxLen); + return 0; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + return 0; + } + + const float inOutRowRatio = static_cast(inputRows) / + static_cast(outputRows); + + return std::round(static_cast(inputCtxLen) / inOutRowRatio); + } + +} /* namespace app */ +} /* namespace arm */ -- cgit v1.2.1