diff options
Diffstat (limited to 'source/use_case')
18 files changed, 329 insertions, 317 deletions
diff --git a/source/use_case/asr/include/AsrClassifier.hpp b/source/use_case/asr/include/AsrClassifier.hpp index 67a200e..a07a721 100644 --- a/source/use_case/asr/include/AsrClassifier.hpp +++ b/source/use_case/asr/include/AsrClassifier.hpp @@ -35,10 +35,10 @@ namespace app { * @param[in] use_softmax Whether softmax scaling should be applied to model output. * @return true if successful, false otherwise. **/ - bool GetClassificationResults( - TfLiteTensor* outputTensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax = false) override; + bool GetClassificationResults(TfLiteTensor* outputTensor, + std::vector<ClassificationResult>& vecResults, + const std::vector<std::string>& labels, + uint32_t topNCount, bool use_softmax = false) override; private: /** @@ -54,7 +54,7 @@ namespace app { template<typename T> bool GetTopResults(TfLiteTensor* tensor, std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint); + const std::vector<std::string>& labels, double scale, double zeroPoint); }; } /* namespace app */ diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp index 895df2b..0078e44 100644 --- a/source/use_case/asr/include/Wav2LetterModel.hpp +++ b/source/use_case/asr/include/Wav2LetterModel.hpp @@ -36,6 +36,7 @@ namespace app { static constexpr uint32_t ms_outputRowsIdx = 2; static constexpr uint32_t ms_outputColsIdx = 3; + /* Model specific constants. */ static constexpr uint32_t ms_blankTokenIdx = 28; static constexpr uint32_t ms_numMfccFeatures = 13; diff --git a/source/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/use_case/asr/include/Wav2LetterPostprocess.hpp index 45defa5..446014d 100644 --- a/source/use_case/asr/include/Wav2LetterPostprocess.hpp +++ b/source/use_case/asr/include/Wav2LetterPostprocess.hpp @@ -30,23 +30,24 @@ namespace app { * @brief Helper class to manage tensor post-processing for "wav2letter" * output. */ - class ASRPostProcess : public BasePostProcess { + class AsrPostProcess : public BasePostProcess { public: bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */ /** * @brief Constructor - * @param[in] outputTensor Pointer to the output Tensor. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Object used to get top N results from classification. * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in/out] result Vector of classification results to store decoded outputs. + * @param[in/out] result Vector of classification results to store decoded outputs. * @param[in] outputContextLen Left/right context length for output tensor. * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes. * @param[in] reductionAxis The axis that the logits of each time step is on. **/ - ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor, - const std::vector<std::string>& labels, asr::ResultVec& result, - uint32_t outputContextLen, - uint32_t blankTokenIdx, uint32_t reductionAxis); + AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector<std::string>& labels, asr::ResultVec& result, + uint32_t outputContextLen, + uint32_t blankTokenIdx, uint32_t reductionAxis); /** * @brief Should perform post-processing of the result of inference then diff --git a/source/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/use_case/asr/include/Wav2LetterPreprocess.hpp index 8c12b3d..dc9a415 100644 --- a/source/use_case/asr/include/Wav2LetterPreprocess.hpp +++ b/source/use_case/asr/include/Wav2LetterPreprocess.hpp @@ -31,22 +31,22 @@ namespace app { * for ASR. */ using AudioWindow = audio::SlidingWindow<const int16_t>; - class ASRPreProcess : public BasePreProcess { + class AsrPreProcess : public BasePreProcess { public: /** * @brief Constructor. * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. * @param[in] numMfccFeatures Number of MFCC features per window. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window. * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window. - * @param[in] mfccWindowStride Number of MFCC vectors that need to be calculated - * for an inference. */ - ASRPreProcess(TfLiteTensor* inputTensor, - uint32_t numMfccFeatures, - uint32_t audioWindowLen, - uint32_t mfccWindowLen, - uint32_t mfccWindowStride); + AsrPreProcess(TfLiteTensor* inputTensor, + uint32_t numMfccFeatures, + uint32_t numFeatureFrames, + uint32_t mfccWindowLen, + uint32_t mfccWindowStride); /** * @brief Calculates the features required from audio data. This @@ -130,9 +130,9 @@ namespace app { } /* Populate. */ - T * outputBufMfcc = outputBuf; - T * outputBufD1 = outputBuf + this->m_numMfccFeats; - T * outputBufD2 = outputBufD1 + this->m_numMfccFeats; + T* outputBufMfcc = outputBuf; + T* outputBufD1 = outputBuf + this->m_numMfccFeats; + T* outputBufD2 = outputBufD1 + this->m_numMfccFeats; const uint32_t ptrIncr = this->m_numMfccFeats * 2; /* (3 vectors - 1 vector) */ const float minVal = std::numeric_limits<T>::min(); @@ -141,13 +141,13 @@ namespace app { /* Need to transpose while copying and concatenating the tensor. */ for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) { for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) { - *outputBufMfcc++ = static_cast<T>(ASRPreProcess::GetQuantElem( + *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_mfccBuf(i, j), quantScale, quantOffset, minVal, maxVal)); - *outputBufD1++ = static_cast<T>(ASRPreProcess::GetQuantElem( + *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_delta1Buf(i, j), quantScale, quantOffset, minVal, maxVal)); - *outputBufD2++ = static_cast<T>(ASRPreProcess::GetQuantElem( + *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_delta2Buf(i, j), quantScale, quantOffset, minVal, maxVal)); } diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc index 84e66b7..4ba8c7b 100644 --- a/source/use_case/asr/src/AsrClassifier.cc +++ b/source/use_case/asr/src/AsrClassifier.cc @@ -20,117 +20,125 @@ #include "TensorFlowLiteMicro.hpp" #include "Wav2LetterModel.hpp" -template<typename T> -bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint) -{ - const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx]; - const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]; - - if (nLetters != labels.size()) { - printf("Output size doesn't match the labels' size\n"); - return false; - } +namespace arm { +namespace app { + + template<typename T> + bool AsrClassifier::GetTopResults(TfLiteTensor* tensor, + std::vector<ClassificationResult>& vecResults, + const std::vector <std::string>& labels, double scale, double zeroPoint) + { + const uint32_t nElems = tensor->dims->data[Wav2LetterModel::ms_outputRowsIdx]; + const uint32_t nLetters = tensor->dims->data[Wav2LetterModel::ms_outputColsIdx]; + + if (nLetters != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } - /* NOTE: tensor's size verification against labels should be - * checked by the calling/public function. */ - if (nLetters < 1) { - return false; - } + /* NOTE: tensor's size verification against labels should be + * checked by the calling/public function. */ + if (nLetters < 1) { + return false; + } - /* Final results' container. */ - vecResults = std::vector<ClassificationResult>(nElems); + /* Final results' container. */ + vecResults = std::vector<ClassificationResult>(nElems); - T* tensorData = tflite::GetTensorData<T>(tensor); + T* tensorData = tflite::GetTensorData<T>(tensor); - /* Get the top 1 results. */ - for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) { - std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0); + /* Get the top 1 results. */ + for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) { + std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0); - for (uint32_t j = 1; j < nLetters; ++j) { - if (top_1.first < tensorData[row + j]) { - top_1.first = tensorData[row + j]; - top_1.second = j; + for (uint32_t j = 1; j < nLetters; ++j) { + if (top_1.first < tensorData[row + j]) { + top_1.first = tensorData[row + j]; + top_1.second = j; + } } + + double score = static_cast<int> (top_1.first); + vecResults[i].m_normalisedVal = scale * (score - zeroPoint); + vecResults[i].m_label = labels[top_1.second]; + vecResults[i].m_labelIdx = top_1.second; } - double score = static_cast<int> (top_1.first); - vecResults[i].m_normalisedVal = scale * (score - zeroPoint); - vecResults[i].m_label = labels[top_1.second]; - vecResults[i].m_labelIdx = top_1.second; + return true; } - - return true; -} -template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint); -template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint); - -bool arm::app::AsrClassifier::GetClassificationResults( + template bool AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor, + std::vector<ClassificationResult>& vecResults, + const std::vector <std::string>& labels, + double scale, double zeroPoint); + template bool AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor, + std::vector<ClassificationResult>& vecResults, + const std::vector <std::string>& labels, + double scale, double zeroPoint); + + bool AsrClassifier::GetClassificationResults( TfLiteTensor* outputTensor, std::vector<ClassificationResult>& vecResults, const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax) -{ - UNUSED(use_softmax); - vecResults.clear(); + { + UNUSED(use_softmax); + vecResults.clear(); - constexpr int minTensorDims = static_cast<int>( - (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)? - arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx); + constexpr int minTensorDims = static_cast<int>( + (Wav2LetterModel::ms_outputRowsIdx > Wav2LetterModel::ms_outputColsIdx)? + Wav2LetterModel::ms_outputRowsIdx : Wav2LetterModel::ms_outputColsIdx); - constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx; + constexpr uint32_t outColsIdx = Wav2LetterModel::ms_outputColsIdx; - /* Sanity checks. */ - if (outputTensor == nullptr) { - printf_err("Output vector is null pointer.\n"); - return false; - } else if (outputTensor->dims->size < minTensorDims) { - printf_err("Output tensor expected to be %dD\n", minTensorDims); - return false; - } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) { - printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount); - return false; - } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) { - printf("Output size doesn't match the labels' size\n"); - return false; - } + /* Sanity checks. */ + if (outputTensor == nullptr) { + printf_err("Output vector is null pointer.\n"); + return false; + } else if (outputTensor->dims->size < minTensorDims) { + printf_err("Output tensor expected to be %dD\n", minTensorDims); + return false; + } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) { + printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount); + return false; + } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } - if (topNCount != 1) { - warn("TopNCount value ignored in this implementation\n"); - } + if (topNCount != 1) { + warn("TopNCount value ignored in this implementation\n"); + } - /* To return the floating point values, we need quantization parameters. */ - QuantParams quantParams = GetTensorQuantParams(outputTensor); - - bool resultState; - - switch (outputTensor->type) { - case kTfLiteUInt8: - resultState = this->GetTopResults<uint8_t>( - outputTensor, vecResults, - labels, quantParams.scale, - quantParams.offset); - break; - case kTfLiteInt8: - resultState = this->GetTopResults<int8_t>( - outputTensor, vecResults, - labels, quantParams.scale, - quantParams.offset); - break; - default: - printf_err("Tensor type %s not supported by classifier\n", - TfLiteTypeGetName(outputTensor->type)); + /* To return the floating point values, we need quantization parameters. */ + QuantParams quantParams = GetTensorQuantParams(outputTensor); + + bool resultState; + + switch (outputTensor->type) { + case kTfLiteUInt8: + resultState = this->GetTopResults<uint8_t>( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + case kTfLiteInt8: + resultState = this->GetTopResults<int8_t>( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + default: + printf_err("Tensor type %s not supported by classifier\n", + TfLiteTypeGetName(outputTensor->type)); + return false; + } + + if (!resultState) { + printf_err("Failed to get sorted set\n"); return false; - } + } - if (!resultState) { - printf_err("Failed to get sorted set\n"); - return false; - } + return true; + } - return true; -}
\ No newline at end of file +} /* namespace app */ +} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc index 7fe959b..850bdc2 100644 --- a/source/use_case/asr/src/UseCaseHandler.cc +++ b/source/use_case/asr/src/UseCaseHandler.cc @@ -33,9 +33,9 @@ namespace arm { namespace app { /** - * @brief Presents ASR inference results. - * @param[in] results Vector of ASR classification results to be displayed. - * @return true if successful, false otherwise. + * @brief Presents ASR inference results. + * @param[in] results Vector of ASR classification results to be displayed. + * @return true if successful, false otherwise. **/ static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results); @@ -63,6 +63,9 @@ namespace app { return false; } + TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + /* Get input shape. Dimensions of the tensor should have been verified by * the callee. */ TfLiteIntArray* inputShape = model.GetInputShape(0); @@ -78,19 +81,19 @@ namespace app { const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq); /* Set up pre and post-processing objects. */ - ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures, - inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride); + AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[Wav2LetterModel::ms_inputRowsIdx], + mfccFrameLen, mfccFrameStride); std::vector<ClassificationResult> singleInfResult; - const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen); - ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"), - model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"), + const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen); + AsrPostProcess postProcess = AsrPostProcess( + outputTensor, ctx.Get<AsrClassifier&>("classifier"), + ctx.Get<std::vector<std::string>&>("labels"), singleInfResult, outputCtxLen, Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx ); - UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model); - /* Loop to process audio clips. */ do { hal_lcd_clear(COLOR_BLACK); @@ -147,16 +150,20 @@ namespace app { static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); /* Run the pre-processing, inference and post-processing. */ - runner.PreProcess(inferenceWindow, inferenceWindowLen); + if (!preProcess.DoPreProcess(inferenceWindow, inferenceWindowLen)) { + printf_err("Pre-processing failed."); + return false; + } - profiler.StartProfiling("Inference"); - if (!runner.RunInference()) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } - profiler.StopProfiling(); + /* Post processing needs to know if we are on the last audio window. */ postProcess.m_lastIteration = !audioDataSlider.HasNext(); - if (!runner.PostProcess()) { + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } @@ -166,7 +173,6 @@ namespace app { audioDataSlider.Index(), scoreThreshold)); #if VERIFY_TEST_OUTPUT - TfLiteTensor* outputTensor = model.GetOutputTensor(0); armDumpTensor(outputTensor, outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); #endif /* VERIFY_TEST_OUTPUT */ diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc index e3e1999..42f434e 100644 --- a/source/use_case/asr/src/Wav2LetterPostprocess.cc +++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc @@ -24,7 +24,7 @@ namespace arm { namespace app { - ASRPostProcess::ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor, + AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results, const uint32_t outputContextLen, const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx @@ -38,11 +38,11 @@ namespace app { m_blankTokenIdx(blankTokenIdx), m_reductionAxisIdx(reductionAxisIdx) { - this->m_outputInnerLen = ASRPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); + this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen); } - bool ASRPostProcess::DoPostProcess() + bool AsrPostProcess::DoPostProcess() { /* Basic checks. */ if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) { @@ -51,7 +51,7 @@ namespace app { /* Irrespective of tensor type, we use unsigned "byte" */ auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor); - const uint32_t elemSz = ASRPostProcess::GetTensorElementSize(this->m_outputTensor); + const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor); /* Other sanity checks. */ if (0 == elemSz) { @@ -79,7 +79,7 @@ namespace app { return true; } - bool ASRPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const + bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const { if (nullptr == tensor) { return false; @@ -101,7 +101,7 @@ namespace app { return true; } - uint32_t ASRPostProcess::GetTensorElementSize(TfLiteTensor* tensor) + uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor) { switch(tensor->type) { case kTfLiteUInt8: @@ -120,7 +120,7 @@ namespace app { return 0; } - bool ASRPostProcess::EraseSectionsRowWise( + bool AsrPostProcess::EraseSectionsRowWise( uint8_t* ptrData, const uint32_t strideSzBytes, const bool lastIteration) @@ -157,7 +157,7 @@ namespace app { return true; } - uint32_t ASRPostProcess::GetNumFeatureVectors(const Model& model) + uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model) { TfLiteTensor* inputTensor = model.GetInputTensor(0); const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0); @@ -168,21 +168,23 @@ namespace app { return inputRows; } - uint32_t ASRPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) + uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) { const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0); if (outputRows == 0) { printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", Wav2LetterModel::ms_outputRowsIdx); } + + /* Watching for underflow. */ int innerLen = (outputRows - (2 * outputCtxLen)); return std::max(innerLen, 0); } - uint32_t ASRPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) + uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) { - const uint32_t inputRows = ASRPostProcess::GetNumFeatureVectors(model); + const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model); const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx; diff --git a/source/use_case/asr/src/Wav2LetterPreprocess.cc b/source/use_case/asr/src/Wav2LetterPreprocess.cc index 590d08a..92b0631 100644 --- a/source/use_case/asr/src/Wav2LetterPreprocess.cc +++ b/source/use_case/asr/src/Wav2LetterPreprocess.cc @@ -25,9 +25,9 @@ namespace arm { namespace app { - ASRPreProcess::ASRPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures, - const uint32_t numFeatureFrames, const uint32_t mfccWindowLen, - const uint32_t mfccWindowStride + AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures, + const uint32_t numFeatureFrames, const uint32_t mfccWindowLen, + const uint32_t mfccWindowStride ): m_mfcc(numMfccFeatures, mfccWindowLen), m_inputTensor(inputTensor), @@ -44,7 +44,7 @@ namespace app { } } - bool ASRPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen) + bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen) { this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>( static_cast<const int16_t*>(audioData), audioDataLen, @@ -82,7 +82,7 @@ namespace app { } /* Compute first and second order deltas from MFCCs. */ - ASRPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf); + AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf); /* Standardize calculated features. */ this->Standarize(); @@ -112,9 +112,9 @@ namespace app { return false; } - bool ASRPreProcess::ComputeDeltas(Array2d<float>& mfcc, - Array2d<float>& delta1, - Array2d<float>& delta2) + bool AsrPreProcess::ComputeDeltas(Array2d<float>& mfcc, + Array2d<float>& delta1, + Array2d<float>& delta2) { const std::vector <float> delta1Coeffs = {6.66666667e-02, 5.00000000e-02, 3.33333333e-02, @@ -167,7 +167,7 @@ namespace app { return true; } - void ASRPreProcess::StandardizeVecF32(Array2d<float>& vec) + void AsrPreProcess::StandardizeVecF32(Array2d<float>& vec) { auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); @@ -186,14 +186,14 @@ namespace app { } } - void ASRPreProcess::Standarize() + void AsrPreProcess::Standarize() { - ASRPreProcess::StandardizeVecF32(this->m_mfccBuf); - ASRPreProcess::StandardizeVecF32(this->m_delta1Buf); - ASRPreProcess::StandardizeVecF32(this->m_delta2Buf); + AsrPreProcess::StandardizeVecF32(this->m_mfccBuf); + AsrPreProcess::StandardizeVecF32(this->m_delta1Buf); + AsrPreProcess::StandardizeVecF32(this->m_delta2Buf); } - float ASRPreProcess::GetQuantElem( + float AsrPreProcess::GetQuantElem( const float elem, const float quantScale, const int quantOffset, diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp index 59db4a5..e931b7d 100644 --- a/source/use_case/img_class/include/ImgClassProcessing.hpp +++ b/source/use_case/img_class/include/ImgClassProcessing.hpp @@ -34,9 +34,10 @@ namespace app { public: /** * @brief Constructor - * @param[in] model Pointer to the the Image classification Model object. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] convertToInt8 Should the image be converted to Int8 range. **/ - explicit ImgClassPreProcess(Model* model); + explicit ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8); /** * @brief Should perform pre-processing of 'raw' input image data and load it into @@ -46,6 +47,10 @@ namespace app { * @return true if successful, false otherwise. **/ bool DoPreProcess(const void* input, size_t inputSize) override; + + private: + TfLiteTensor* m_inputTensor; + bool m_convertToInt8; }; /** @@ -55,29 +60,30 @@ namespace app { */ class ImgClassPostProcess : public BasePostProcess { - private: - Classifier& m_imgClassifier; - const std::vector<std::string>& m_labels; - std::vector<ClassificationResult>& m_results; - public: /** * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the the Image classification Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in] results Vector of classification results to store decoded outputs. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in] results Vector of classification results to store decoded outputs. **/ - ImgClassPostProcess(Classifier& classifier, Model* model, + ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results); /** - * @brief Should perform post-processing of the result of inference then populate + * @brief Should perform post-processing of the result of inference then * populate classification result data for any later use. * @return true if successful, false otherwise. **/ bool DoPostProcess() override; + + private: + TfLiteTensor* m_outputTensor; + Classifier& m_imgClassifier; + const std::vector<std::string>& m_labels; + std::vector<ClassificationResult>& m_results; }; } /* namespace app */ diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc index 6ba88ad..adf9794 100644 --- a/source/use_case/img_class/src/ImgClassProcessing.cc +++ b/source/use_case/img_class/src/ImgClassProcessing.cc @@ -21,50 +21,43 @@ namespace arm { namespace app { - ImgClassPreProcess::ImgClassPreProcess(Model* model) - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + ImgClassPreProcess::ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8) + :m_inputTensor{inputTensor}, + m_convertToInt8{convertToInt8} + {} bool ImgClassPreProcess::DoPreProcess(const void* data, size_t inputSize) { if (data == nullptr) { printf_err("Data pointer is null"); + return false; } auto input = static_cast<const uint8_t*>(data); - TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0); - std::memcpy(inputTensor->data.data, input, inputSize); + std::memcpy(this->m_inputTensor->data.data, input, inputSize); debug("Input tensor populated \n"); - if (this->m_model->IsDataSigned()) { - image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + if (this->m_convertToInt8) { + image::ConvertImgToInt8(this->m_inputTensor->data.data, this->m_inputTensor->bytes); } return true; } - ImgClassPostProcess::ImgClassPostProcess(Classifier& classifier, Model* model, + ImgClassPostProcess::ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results) - :m_imgClassifier{classifier}, + :m_outputTensor{outputTensor}, + m_imgClassifier{classifier}, m_labels{labels}, m_results{results} - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + {} bool ImgClassPostProcess::DoPostProcess() { return this->m_imgClassifier.GetClassificationResults( - this->m_model->GetOutputTensor(0), this->m_results, + this->m_outputTensor, this->m_results, this->m_labels, 5, false); } diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index c68d816..5cc3959 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -59,6 +59,7 @@ namespace app { } TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; @@ -74,13 +75,12 @@ namespace app { const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx]; /* Set up pre and post-processing. */ - ImgClassPreProcess preprocess = ImgClassPreProcess(&model); + ImgClassPreProcess preProcess = ImgClassPreProcess(inputTensor, model.IsDataSigned()); std::vector<ClassificationResult> results; - ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model, - ctx.Get<std::vector<std::string>&>("labels"), results); - - UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); + ImgClassPostProcess postProcess = ImgClassPostProcess(outputTensor, + ctx.Get<ImgClassClassifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"), + results); do { hal_lcd_clear(COLOR_BLACK); @@ -113,17 +113,18 @@ namespace app { inputTensor->bytes : IMAGE_DATA_SIZE; /* Run the pre-processing, inference and post-processing. */ - if (!runner.PreProcess(imgSrc, imgSz)) { + if (!preProcess.DoPreProcess(imgSrc, imgSz)) { + printf_err("Pre-processing failed."); return false; } - profiler.StartProfiling("Inference"); - if (!runner.RunInference()) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } - profiler.StopProfiling(); - if (!runner.PostProcess()) { + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } @@ -136,7 +137,6 @@ namespace app { ctx.Set<std::vector<ClassificationResult>>("results", results); #if VERIFY_TEST_OUTPUT - TfLiteTensor* outputTensor = model.GetOutputTensor(0); arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp index ddf38c1..d3de3b3 100644 --- a/source/use_case/kws/include/KwsProcessing.hpp +++ b/source/use_case/kws/include/KwsProcessing.hpp @@ -33,18 +33,21 @@ namespace app { * Implements methods declared by BasePreProcess and anything else needed * to populate input tensors ready for inference. */ - class KWSPreProcess : public BasePreProcess { + class KwsPreProcess : public BasePreProcess { public: /** * @brief Constructor - * @param[in] model Pointer to the KWS Model object. - * @param[in] numFeatures How many MFCC features to use. - * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when - * sliding a window through the audio sample. - * @param[in] mfccFrameStride Number of audio samples between consecutive windows. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numFeatures How many MFCC features to use. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. + * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when + * sliding a window through the audio sample. + * @param[in] mfccFrameStride Number of audio samples between consecutive windows. **/ - explicit KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride); + explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames, + int mfccFrameLength, int mfccFrameStride); /** * @brief Should perform pre-processing of 'raw' input audio data and load it into @@ -60,8 +63,10 @@ namespace app { size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */ private: + TfLiteTensor* m_inputTensor; /* Model input tensor. */ const int m_mfccFrameLength; const int m_mfccFrameStride; + const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */ audio::MicroNetKwsMFCC m_mfcc; audio::SlidingWindow<const int16_t> m_mfccSlidingWindow; @@ -99,22 +104,23 @@ namespace app { * Implements methods declared by BasePostProcess and anything else needed * to populate result vector. */ - class KWSPostProcess : public BasePostProcess { + class KwsPostProcess : public BasePostProcess { private: - Classifier& m_kwsClassifier; - const std::vector<std::string>& m_labels; - std::vector<ClassificationResult>& m_results; + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + Classifier& m_kwsClassifier; /* KWS Classifier object. */ + const std::vector<std::string>& m_labels; /* KWS Labels. */ + std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */ public: /** * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the KWS Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in/out] results Vector of classification results to store decoded outputs. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] results Vector of classification results to store decoded outputs. **/ - KWSPostProcess(Classifier& classifier, Model* model, + KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results); diff --git a/source/use_case/kws/include/KwsResult.hpp b/source/use_case/kws/include/KwsResult.hpp index 5a26ce1..38f32b4 100644 --- a/source/use_case/kws/include/KwsResult.hpp +++ b/source/use_case/kws/include/KwsResult.hpp @@ -25,7 +25,7 @@ namespace arm { namespace app { namespace kws { - using ResultVec = std::vector < arm::app::ClassificationResult >; + using ResultVec = std::vector<arm::app::ClassificationResult>; /* Structure for holding kws result. */ class KwsResult { diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc index 14f9fce..328709d 100644 --- a/source/use_case/kws/src/KwsProcessing.cc +++ b/source/use_case/kws/src/KwsProcessing.cc @@ -22,22 +22,19 @@ namespace arm { namespace app { - KWSPreProcess::KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride): + KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames, + int mfccFrameLength, int mfccFrameStride + ): + m_inputTensor{inputTensor}, m_mfccFrameLength{mfccFrameLength}, m_mfccFrameStride{mfccFrameStride}, + m_numMfccFrames{numMfccFrames}, m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)} { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; this->m_mfcc.Init(); - TfLiteIntArray* inputShape = model->GetInputShape(0); - const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; - /* Deduce the data length required for 1 inference from the network parameters. */ - this->m_audioDataWindowSize = numMfccFrames * this->m_mfccFrameStride + + this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride + (this->m_mfccFrameLength - this->m_mfccFrameStride); /* Creating an MFCC feature sliding window for the data required for 1 inference. */ @@ -62,7 +59,7 @@ namespace app { - this->m_numMfccVectorsInAudioStride; /* Construct feature calculation function. */ - this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_model->GetInputTensor(0), + this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor, this->m_numReusedMfccVectors); if (!this->m_mfccFeatureCalculator) { @@ -70,7 +67,7 @@ namespace app { } } - bool KWSPreProcess::DoPreProcess(const void* data, size_t inputSize) + bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize) { UNUSED(inputSize); if (data == nullptr) { @@ -116,8 +113,8 @@ namespace app { */ template<class T> std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> - KWSPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function<std::vector<T> (std::vector<int16_t>& )> compute) + KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function<std::vector<T> (std::vector<int16_t>& )> compute) { /* Feature cache to be captured by lambda function. */ static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); @@ -149,18 +146,18 @@ namespace app { } template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> - KWSPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); + KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)> - KWSPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<float>(std::vector<int16_t>&)> compute); + KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<float>(std::vector<int16_t>&)> compute); std::function<void (std::vector<int16_t>&, int, bool, size_t)> - KWSPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) { std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; @@ -195,23 +192,19 @@ namespace app { return mfccFeatureCalc; } - KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model, + KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results) - :m_kwsClassifier{classifier}, + :m_outputTensor{outputTensor}, + m_kwsClassifier{classifier}, m_labels{labels}, m_results{results} - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + {} - bool KWSPostProcess::DoPostProcess() + bool KwsPostProcess::DoPostProcess() { return this->m_kwsClassifier.GetClassificationResults( - this->m_model->GetOutputTensor(0), this->m_results, + this->m_outputTensor, this->m_results, this->m_labels, 1, true); } diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index e73a2c3..61c6eb6 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -34,13 +34,12 @@ using KwsClassifier = arm::app::Classifier; namespace arm { namespace app { - /** * @brief Presents KWS inference results. * @param[in] results Vector of KWS classification results to be displayed. * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(const std::vector<arm::app::kws::KwsResult>& results); + static bool PresentInferenceResult(const std::vector<kws::KwsResult>& results); /* KWS inference handler. */ bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) @@ -50,6 +49,7 @@ namespace app { const auto mfccFrameLength = ctx.Get<int>("frameLength"); const auto mfccFrameStride = ctx.Get<int>("frameStride"); const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); + /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { @@ -61,16 +61,17 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast<int>( - (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? - arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); - + (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)? + MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx); if (!model.IsInited()) { printf_err("Model is not initialised! Terminating processing.\n"); return false; } + /* Get Input and Output tensors for pre/post processing. */ TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; @@ -81,22 +82,23 @@ namespace app { /* Get input shape for feature extraction. */ TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t numMfccFeatures = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; /* We expect to be sampling 1 second worth of data at a time. * NOTE: This is only used for time stamp calculation. */ const float secondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; /* Set up pre and post-processing. */ - KWSPreProcess preprocess = KWSPreProcess(&model, numMfccFeatures, mfccFrameLength, mfccFrameStride); + KwsPreProcess preProcess = KwsPreProcess(inputTensor, numMfccFeatures, numMfccFrames, + mfccFrameLength, mfccFrameStride); std::vector<ClassificationResult> singleInfResult; - KWSPostProcess postprocess = KWSPostProcess(ctx.Get<KwsClassifier &>("classifier"), &model, + KwsPostProcess postProcess = KwsPostProcess(outputTensor, ctx.Get<KwsClassifier &>("classifier"), ctx.Get<std::vector<std::string>&>("labels"), singleInfResult); - UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); - + /* Loop to process audio clips. */ do { hal_lcd_clear(COLOR_BLACK); @@ -106,7 +108,7 @@ namespace app { auto audioDataSlider = audio::SlidingWindow<const int16_t>( get_audio_array(currentIndex), get_audio_array_size(currentIndex), - preprocess.m_audioDataWindowSize, preprocess.m_audioDataStride); + preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride); /* Declare a container to hold results from across the whole audio clip. */ std::vector<kws::KwsResult> finalResults; @@ -123,34 +125,34 @@ namespace app { const int16_t* inferenceWindow = audioDataSlider.Next(); /* The first window does not have cache ready. */ - preprocess.m_audioWindowIndex = audioDataSlider.Index(); + preProcess.m_audioWindowIndex = audioDataSlider.Index(); info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Run the pre-processing, inference and post-processing. */ - if (!runner.PreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) { + if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) { + printf_err("Pre-processing failed."); return false; } - profiler.StartProfiling("Inference"); - if (!runner.RunInference()) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } - profiler.StopProfiling(); - if (!runner.PostProcess()) { + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } /* Add results from this window to our final results vector. */ finalResults.emplace_back(kws::KwsResult(singleInfResult, - audioDataSlider.Index() * secondsPerSample * preprocess.m_audioDataStride, + audioDataSlider.Index() * secondsPerSample * preProcess.m_audioDataStride, audioDataSlider.Index(), scoreThreshold)); #if VERIFY_TEST_OUTPUT - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - arm::app::DumpTensor(outputTensor); + DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ @@ -174,7 +176,7 @@ namespace app { return true; } - static bool PresentInferenceResult(const std::vector<arm::app::kws::KwsResult>& results) + static bool PresentInferenceResult(const std::vector<kws::KwsResult>& results) { constexpr uint32_t dataPsnTxtStartX1 = 20; constexpr uint32_t dataPsnTxtStartY1 = 30; @@ -187,7 +189,7 @@ namespace app { /* Display each result */ uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - for (const auto & result : results) { + for (const auto& result : results) { std::string topKeyword{"<none>"}; float score = 0.f; diff --git a/source/use_case/vww/include/VisualWakeWordProcessing.hpp b/source/use_case/vww/include/VisualWakeWordProcessing.hpp index b1d68ce..bef161f 100644 --- a/source/use_case/vww/include/VisualWakeWordProcessing.hpp +++ b/source/use_case/vww/include/VisualWakeWordProcessing.hpp @@ -34,9 +34,9 @@ namespace app { public: /** * @brief Constructor - * @param[in] model Pointer to the the Image classification Model object. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. **/ - explicit VisualWakeWordPreProcess(Model* model); + explicit VisualWakeWordPreProcess(TfLiteTensor* inputTensor); /** * @brief Should perform pre-processing of 'raw' input image data and load it into @@ -46,6 +46,9 @@ namespace app { * @return true if successful, false otherwise. **/ bool DoPreProcess(const void* input, size_t inputSize) override; + + private: + TfLiteTensor* m_inputTensor; }; /** @@ -56,6 +59,7 @@ namespace app { class VisualWakeWordPostProcess : public BasePostProcess { private: + TfLiteTensor* m_outputTensor; Classifier& m_vwwClassifier; const std::vector<std::string>& m_labels; std::vector<ClassificationResult>& m_results; @@ -63,19 +67,20 @@ namespace app { public: /** * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the VWW classification Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[out] results Vector of classification results to store decoded outputs. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] model Pointer to the VWW classification Model object. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[out] results Vector of classification results to store decoded outputs. **/ - VisualWakeWordPostProcess(Classifier& classifier, Model* model, + VisualWakeWordPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results); /** - * @brief Should perform post-processing of the result of inference then - * populate classification result data for any later use. - * @return true if successful, false otherwise. + * @brief Should perform post-processing of the result of inference then + * populate classification result data for any later use. + * @return true if successful, false otherwise. **/ bool DoPostProcess() override; }; diff --git a/source/use_case/vww/src/UseCaseHandler.cc b/source/use_case/vww/src/UseCaseHandler.cc index 7681f89..267e6c4 100644 --- a/source/use_case/vww/src/UseCaseHandler.cc +++ b/source/use_case/vww/src/UseCaseHandler.cc @@ -53,7 +53,7 @@ namespace app { } TfLiteTensor* inputTensor = model.GetInputTensor(0); - + TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; @@ -75,15 +75,13 @@ namespace app { const uint32_t displayChannels = 3; /* Set up pre and post-processing. */ - VisualWakeWordPreProcess preprocess = VisualWakeWordPreProcess(&model); + VisualWakeWordPreProcess preProcess = VisualWakeWordPreProcess(inputTensor); std::vector<ClassificationResult> results; - VisualWakeWordPostProcess postprocess = VisualWakeWordPostProcess( - ctx.Get<Classifier&>("classifier"), &model, + VisualWakeWordPostProcess postProcess = VisualWakeWordPostProcess(outputTensor, + ctx.Get<Classifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"), results); - UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); - do { hal_lcd_clear(COLOR_BLACK); @@ -115,17 +113,18 @@ namespace app { inputTensor->bytes : IMAGE_DATA_SIZE; /* Run the pre-processing, inference and post-processing. */ - if (!runner.PreProcess(imgSrc, imgSz)) { + if (!preProcess.DoPreProcess(imgSrc, imgSz)) { + printf_err("Pre-processing failed."); return false; } - profiler.StartProfiling("Inference"); - if (!runner.RunInference()) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } - profiler.StopProfiling(); - if (!runner.PostProcess()) { + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } @@ -138,7 +137,6 @@ namespace app { ctx.Set<std::vector<ClassificationResult>>("results", results); #if VERIFY_TEST_OUTPUT - TfLiteTensor* outputTensor = model.GetOutputTensor(0); arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ diff --git a/source/use_case/vww/src/VisualWakeWordProcessing.cc b/source/use_case/vww/src/VisualWakeWordProcessing.cc index 94eae28..a9863c0 100644 --- a/source/use_case/vww/src/VisualWakeWordProcessing.cc +++ b/source/use_case/vww/src/VisualWakeWordProcessing.cc @@ -22,13 +22,9 @@ namespace arm { namespace app { - VisualWakeWordPreProcess::VisualWakeWordPreProcess(Model* model) - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + VisualWakeWordPreProcess::VisualWakeWordPreProcess(TfLiteTensor* inputTensor) + :m_inputTensor{inputTensor} + {} bool VisualWakeWordPreProcess::DoPreProcess(const void* data, size_t inputSize) { @@ -37,9 +33,8 @@ namespace app { } auto input = static_cast<const uint8_t*>(data); - TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0); - auto unsignedDstPtr = static_cast<uint8_t*>(inputTensor->data.data); + auto unsignedDstPtr = static_cast<uint8_t*>(this->m_inputTensor->data.data); /* VWW model has one channel input => Convert image to grayscale here. * We expect images to always be RGB. */ @@ -47,10 +42,10 @@ namespace app { /* VWW model pre-processing is image conversion from uint8 to [0,1] float values, * then quantize them with input quantization info. */ - QuantParams inQuantParams = GetTensorQuantParams(inputTensor); + QuantParams inQuantParams = GetTensorQuantParams(this->m_inputTensor); - auto signedDstPtr = static_cast<int8_t*>(inputTensor->data.data); - for (size_t i = 0; i < inputTensor->bytes; i++) { + auto signedDstPtr = static_cast<int8_t*>(this->m_inputTensor->data.data); + for (size_t i = 0; i < this->m_inputTensor->bytes; i++) { auto i_data_int8 = static_cast<int8_t>( ((static_cast<float>(unsignedDstPtr[i]) / 255.0f) / inQuantParams.scale) + inQuantParams.offset ); @@ -62,22 +57,18 @@ namespace app { return true; } - VisualWakeWordPostProcess::VisualWakeWordPostProcess(Classifier& classifier, Model* model, + VisualWakeWordPostProcess::VisualWakeWordPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results) - :m_vwwClassifier{classifier}, + :m_outputTensor{outputTensor}, + m_vwwClassifier{classifier}, m_labels{labels}, m_results{results} - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + {} bool VisualWakeWordPostProcess::DoPostProcess() { return this->m_vwwClassifier.GetClassificationResults( - this->m_model->GetOutputTensor(0), this->m_results, + this->m_outputTensor, this->m_results, this->m_labels, 1, true); } |