diff options
Diffstat (limited to 'source/use_case/asr/src/Wav2LetterPostprocess.cc')
-rw-r--r-- | source/use_case/asr/src/Wav2LetterPostprocess.cc | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc index e3e1999..42f434e 100644 --- a/source/use_case/asr/src/Wav2LetterPostprocess.cc +++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc @@ -24,7 +24,7 @@ namespace arm { namespace app { - ASRPostProcess::ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor, + AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results, const uint32_t outputContextLen, const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx @@ -38,11 +38,11 @@ namespace app { m_blankTokenIdx(blankTokenIdx), m_reductionAxisIdx(reductionAxisIdx) { - this->m_outputInnerLen = ASRPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); + this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen); } - bool ASRPostProcess::DoPostProcess() + bool AsrPostProcess::DoPostProcess() { /* Basic checks. */ if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) { @@ -51,7 +51,7 @@ namespace app { /* Irrespective of tensor type, we use unsigned "byte" */ auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor); - const uint32_t elemSz = ASRPostProcess::GetTensorElementSize(this->m_outputTensor); + const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor); /* Other sanity checks. */ if (0 == elemSz) { @@ -79,7 +79,7 @@ namespace app { return true; } - bool ASRPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const + bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const { if (nullptr == tensor) { return false; @@ -101,7 +101,7 @@ namespace app { return true; } - uint32_t ASRPostProcess::GetTensorElementSize(TfLiteTensor* tensor) + uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor) { switch(tensor->type) { case kTfLiteUInt8: @@ -120,7 +120,7 @@ namespace app { return 0; } - bool ASRPostProcess::EraseSectionsRowWise( + bool AsrPostProcess::EraseSectionsRowWise( uint8_t* ptrData, const uint32_t strideSzBytes, const bool lastIteration) @@ -157,7 +157,7 @@ namespace app { return true; } - uint32_t ASRPostProcess::GetNumFeatureVectors(const Model& model) + uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model) { TfLiteTensor* inputTensor = model.GetInputTensor(0); const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0); @@ -168,21 +168,23 @@ namespace app { return inputRows; } - uint32_t ASRPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) + uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) { const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0); if (outputRows == 0) { printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", Wav2LetterModel::ms_outputRowsIdx); } + + /* Watching for underflow. */ int innerLen = (outputRows - (2 * outputCtxLen)); return std::max(innerLen, 0); } - uint32_t ASRPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) + uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) { - const uint32_t inputRows = ASRPostProcess::GetNumFeatureVectors(model); + const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model); const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx; |